Yash030 Claude Opus 4.7 commited on
Commit
0157ac7
·
1 Parent(s): 0c3f08f

Deploy claude-code-nvidia proxy to Hugging Face Spaces

Browse files

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .env.example +107 -0
  2. Dockerfile +26 -0
  3. api/__init__.py +17 -0
  4. api/__pycache__/__init__.cpython-314.pyc +0 -0
  5. api/__pycache__/app.cpython-314.pyc +0 -0
  6. api/__pycache__/command_utils.cpython-314.pyc +0 -0
  7. api/__pycache__/dependencies.cpython-314.pyc +0 -0
  8. api/__pycache__/detection.cpython-314.pyc +0 -0
  9. api/__pycache__/gateway_model_ids.cpython-314.pyc +0 -0
  10. api/__pycache__/model_router.cpython-314.pyc +0 -0
  11. api/__pycache__/optimization_handlers.cpython-314.pyc +0 -0
  12. api/__pycache__/routes.cpython-314.pyc +0 -0
  13. api/__pycache__/runtime.cpython-314.pyc +0 -0
  14. api/__pycache__/services.cpython-314.pyc +0 -0
  15. api/__pycache__/validation_log.cpython-314.pyc +0 -0
  16. api/app.py +175 -0
  17. api/command_utils.py +164 -0
  18. api/dependencies.py +144 -0
  19. api/detection.py +136 -0
  20. api/gateway_model_ids.py +54 -0
  21. api/model_router.py +261 -0
  22. api/models/__init__.py +45 -0
  23. api/models/__pycache__/__init__.cpython-314.pyc +0 -0
  24. api/models/__pycache__/anthropic.cpython-314.pyc +0 -0
  25. api/models/__pycache__/responses.cpython-314.pyc +0 -0
  26. api/models/anthropic.py +163 -0
  27. api/models/responses.py +56 -0
  28. api/optimization_handlers.py +154 -0
  29. api/routes.py +271 -0
  30. api/runtime.py +338 -0
  31. api/services.py +305 -0
  32. api/validation_log.py +48 -0
  33. api/web_server_tools.py +22 -0
  34. api/web_tools/__init__.py +17 -0
  35. api/web_tools/__pycache__/__init__.cpython-314.pyc +0 -0
  36. api/web_tools/__pycache__/constants.cpython-314.pyc +0 -0
  37. api/web_tools/__pycache__/egress.cpython-314.pyc +0 -0
  38. api/web_tools/__pycache__/parsers.cpython-314.pyc +0 -0
  39. api/web_tools/__pycache__/request.cpython-314.pyc +0 -0
  40. api/web_tools/__pycache__/streaming.cpython-314.pyc +0 -0
  41. api/web_tools/constants.py +15 -0
  42. api/web_tools/egress.py +99 -0
  43. api/web_tools/outbound.py +278 -0
  44. api/web_tools/parsers.py +104 -0
  45. api/web_tools/request.py +86 -0
  46. api/web_tools/streaming.py +206 -0
  47. cli/__init__.py +6 -0
  48. cli/entrypoints.py +60 -0
  49. cli/manager.py +163 -0
  50. cli/process_registry.py +74 -0
.env.example ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NVIDIA NIM Config
2
+ NVIDIA_NIM_API_KEY=""
3
+
4
+
5
+ # All Claude model requests are mapped to these models, plain model is fallback
6
+ # Format: provider_type/model/name
7
+ # Valid provider: "nvidia_nim"
8
+ MODEL_OPUS=
9
+ MODEL_SONNET=
10
+ MODEL_HAIKU=
11
+ MODEL="nvidia_nim/z-ai/glm4.7"
12
+
13
+
14
+ # Optional live smoke model overrides. Smoke runs for NVIDIA NIM.
15
+ FCC_SMOKE_MODEL_NVIDIA_NIM=
16
+
17
+
18
+ # Thinking output
19
+ # Per-Claude-model switches for provider reasoning requests and Claude thinking blocks.
20
+ # Blank per-model switches inherit ENABLE_MODEL_THINKING.
21
+ ENABLE_OPUS_THINKING=
22
+ ENABLE_SONNET_THINKING=
23
+ ENABLE_HAIKU_THINKING=
24
+ ENABLE_MODEL_THINKING=true
25
+
26
+
27
+ # Provider config
28
+ # Per-provider proxy support: http and socks5, example: "http://username:password@host:port"
29
+ NVIDIA_NIM_PROXY=""
30
+
31
+ PROVIDER_RATE_LIMIT=1
32
+ PROVIDER_RATE_WINDOW=3
33
+ PROVIDER_MAX_CONCURRENCY=5
34
+
35
+
36
+ # HTTP client timeouts (seconds) for provider API requests
37
+ HTTP_READ_TIMEOUT=300
38
+ HTTP_WRITE_TIMEOUT=10
39
+ HTTP_CONNECT_TIMEOUT=10
40
+
41
+
42
+ # Optional server API key (Anthropic-style)
43
+ ANTHROPIC_AUTH_TOKEN="freecc"
44
+
45
+
46
+ # Messaging Platform: "telegram" | "discord" | "none"
47
+ MESSAGING_PLATFORM="discord"
48
+ MESSAGING_RATE_LIMIT=1
49
+ MESSAGING_RATE_WINDOW=1
50
+
51
+
52
+ # Voice Note Transcription
53
+ VOICE_NOTE_ENABLED=false
54
+ # WHISPER_DEVICE: "cpu" | "cuda" | "nvidia_nim"
55
+ # - "cpu"/"cuda": Hugging Face transformers Whisper (offline, free; install with: uv sync --extra voice_local)
56
+ # - "nvidia_nim": NVIDIA NIM Whisper via Riva gRPC (requires NVIDIA_NIM_API_KEY; install with: uv sync --extra voice)
57
+ # (Independent of MODEL=nvidia_nim/...: that selects the *chat* provider; this selects voice STT only.)
58
+ WHISPER_DEVICE="nvidia_nim"
59
+ # WHISPER_MODEL:
60
+ # - For cpu/cuda: Hugging Face ID or short name (tiny, base, small, medium, large-v2, large-v3, large-v3-turbo)
61
+ # - For nvidia_nim: NVIDIA NIM model (e.g., "nvidia/parakeet-ctc-1.1b-asr", "openai/whisper-large-v3")
62
+ # - For nvidia_nim, default to "openai/whisper-large-v3" for best performance
63
+ WHISPER_MODEL="openai/whisper-large-v3"
64
+ HF_TOKEN=""
65
+
66
+
67
+ # Telegram Config
68
+ TELEGRAM_BOT_TOKEN=""
69
+ ALLOWED_TELEGRAM_USER_ID=""
70
+
71
+
72
+ # Discord Config
73
+ DISCORD_BOT_TOKEN=""
74
+ ALLOWED_DISCORD_CHANNELS=""
75
+
76
+
77
+ # Agent Config
78
+ CLAUDE_WORKSPACE="./agent_workspace"
79
+ ALLOWED_DIR=""
80
+ CLAUDE_CLI_BIN="claude"
81
+ FAST_PREFIX_DETECTION=true
82
+ ENABLE_NETWORK_PROBE_MOCK=true
83
+ ENABLE_TITLE_GENERATION_SKIP=true
84
+ ENABLE_SUGGESTION_MODE_SKIP=true
85
+ ENABLE_FILEPATH_EXTRACTION_MOCK=true
86
+
87
+
88
+ # Local Anthropic web_search / web_fetch handling (performs outbound HTTP; on by default)
89
+ ENABLE_WEB_SERVER_TOOLS=true
90
+ WEB_FETCH_ALLOWED_SCHEMES=http,https
91
+ WEB_FETCH_ALLOW_PRIVATE_NETWORKS=false
92
+
93
+
94
+ # Verbose diagnostics (avoid logging raw prompts / SSE bodies in production)
95
+ DEBUG_PLATFORM_EDITS=false
96
+ DEBUG_SUBAGENT_STACK=false
97
+ # When true, also allows DEBUG-level httpx/httpcore/telegram log noise (not just payload logging).
98
+ LOG_RAW_API_PAYLOADS=false
99
+ LOG_RAW_SSE_EVENTS=false
100
+ # When true, log full exception text and tracebacks for unhandled errors (may leak request-derived data).
101
+ LOG_API_ERROR_TRACEBACKS=false
102
+ # When true, log message/transcription text previews in messaging adapters (may leak user content).
103
+ LOG_RAW_MESSAGING_CONTENT=false
104
+ # When true, log full Claude CLI stderr, non-JSON stdout lines, and parser error text.
105
+ LOG_RAW_CLI_DIAGNOSTICS=false
106
+ # When true, log full exception and CLI error message strings in messaging (may leak user content).
107
+ LOG_MESSAGING_ERROR_DETAILS=false
Dockerfile ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.14-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install uv
6
+ RUN pip install uv
7
+
8
+ # Copy project files
9
+ COPY pyproject.toml uv.lock ./
10
+ COPY api/ ./api/
11
+ COPY cli/ ./cli/
12
+ COPY config/ ./config/
13
+ COPY core/ ./core/
14
+ COPY messaging/ ./messaging/
15
+ COPY providers/ ./providers/
16
+ COPY server.py ./
17
+ COPY .env.example ./
18
+
19
+ # Install dependencies
20
+ RUN uv sync --frozen --no-dev
21
+
22
+ # Expose port (HF Spaces default)
23
+ EXPOSE 7860
24
+
25
+ # Run server
26
+ CMD ["uv", "run", "uvicorn", "server:app", "--host", "0.0.0.0", "--port", "7860"]
api/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """API layer for Claude Code Proxy."""
2
+
3
+ from .app import create_app
4
+ from .models import (
5
+ MessagesRequest,
6
+ MessagesResponse,
7
+ TokenCountRequest,
8
+ TokenCountResponse,
9
+ )
10
+
11
+ __all__ = [
12
+ "MessagesRequest",
13
+ "MessagesResponse",
14
+ "TokenCountRequest",
15
+ "TokenCountResponse",
16
+ "create_app",
17
+ ]
api/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (431 Bytes). View file
 
api/__pycache__/app.cpython-314.pyc ADDED
Binary file (10.7 kB). View file
 
api/__pycache__/command_utils.cpython-314.pyc ADDED
Binary file (5.83 kB). View file
 
api/__pycache__/dependencies.cpython-314.pyc ADDED
Binary file (7.56 kB). View file
 
api/__pycache__/detection.cpython-314.pyc ADDED
Binary file (6.71 kB). View file
 
api/__pycache__/gateway_model_ids.cpython-314.pyc ADDED
Binary file (2.55 kB). View file
 
api/__pycache__/model_router.cpython-314.pyc ADDED
Binary file (12.6 kB). View file
 
api/__pycache__/optimization_handlers.cpython-314.pyc ADDED
Binary file (5.7 kB). View file
 
api/__pycache__/routes.cpython-314.pyc ADDED
Binary file (14.1 kB). View file
 
api/__pycache__/runtime.cpython-314.pyc ADDED
Binary file (20.1 kB). View file
 
api/__pycache__/services.cpython-314.pyc ADDED
Binary file (14 kB). View file
 
api/__pycache__/validation_log.cpython-314.pyc ADDED
Binary file (2.95 kB). View file
 
api/app.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI application factory and configuration."""
2
+
3
+ import traceback
4
+ from contextlib import asynccontextmanager
5
+ from typing import Any
6
+
7
+ from fastapi import FastAPI, Request
8
+ from fastapi.exception_handlers import request_validation_exception_handler
9
+ from fastapi.exceptions import RequestValidationError
10
+ from fastapi.responses import JSONResponse
11
+ from loguru import logger
12
+ from starlette.types import Receive, Scope, Send
13
+
14
+ from config.logging_config import configure_logging
15
+ from config.settings import get_settings
16
+ from providers.exceptions import ProviderError
17
+
18
+ from .routes import router
19
+ from .runtime import AppRuntime, startup_failure_message
20
+ from .validation_log import summarize_request_validation_body
21
+
22
+
23
+ @asynccontextmanager
24
+ async def lifespan(app: FastAPI):
25
+ """Application lifespan manager."""
26
+ runtime = AppRuntime.for_app(app, settings=get_settings())
27
+ await runtime.startup()
28
+
29
+ yield
30
+
31
+ await runtime.shutdown()
32
+
33
+
34
+ class GracefulLifespanApp:
35
+ """ASGI wrapper that reports startup failures without Starlette tracebacks."""
36
+
37
+ def __init__(self, app: FastAPI):
38
+ self.app = app
39
+
40
+ def __getattr__(self, name: str) -> Any:
41
+ return getattr(self.app, name)
42
+
43
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
44
+ if scope["type"] != "lifespan":
45
+ await self.app(scope, receive, send)
46
+ return
47
+ await self._lifespan(receive, send)
48
+
49
+ async def _lifespan(self, receive: Receive, send: Send) -> None:
50
+ settings = get_settings()
51
+ runtime = AppRuntime.for_app(self.app, settings=settings)
52
+ startup_complete = False
53
+ while True:
54
+ message = await receive()
55
+ if message["type"] == "lifespan.startup":
56
+ try:
57
+ await runtime.startup()
58
+ except Exception as exc:
59
+ await send(
60
+ {
61
+ "type": "lifespan.startup.failed",
62
+ "message": startup_failure_message(settings, exc),
63
+ }
64
+ )
65
+ return
66
+ startup_complete = True
67
+ await send({"type": "lifespan.startup.complete"})
68
+ continue
69
+
70
+ if message["type"] == "lifespan.shutdown":
71
+ if startup_complete:
72
+ try:
73
+ await runtime.shutdown()
74
+ except Exception as exc:
75
+ logger.error("Shutdown failed: exc_type={}", type(exc).__name__)
76
+ await send({"type": "lifespan.shutdown.failed", "message": ""})
77
+ return
78
+ await send({"type": "lifespan.shutdown.complete"})
79
+ return
80
+
81
+
82
+ def create_app(*, lifespan_enabled: bool = True) -> FastAPI:
83
+ """Create and configure the FastAPI application."""
84
+ settings = get_settings()
85
+ configure_logging(
86
+ settings.log_file, verbose_third_party=settings.log_raw_api_payloads
87
+ )
88
+
89
+ app_kwargs: dict[str, Any] = {
90
+ "title": "Claude Code Proxy",
91
+ "version": "2.0.0",
92
+ }
93
+ if lifespan_enabled:
94
+ app_kwargs["lifespan"] = lifespan
95
+ app = FastAPI(**app_kwargs)
96
+
97
+ # Register routes
98
+ app.include_router(router)
99
+
100
+ # Exception handlers
101
+ @app.exception_handler(RequestValidationError)
102
+ async def validation_error_handler(request: Request, exc: RequestValidationError):
103
+ """Log request shape for 422 debugging without content values."""
104
+ body: Any
105
+ try:
106
+ body = await request.json()
107
+ except Exception as e:
108
+ body = {"_json_error": type(e).__name__}
109
+
110
+ message_summary, tool_names = summarize_request_validation_body(body)
111
+
112
+ logger.debug(
113
+ "Request validation failed: path={} query={} error_locs={} error_types={} message_summary={} tool_names={}",
114
+ request.url.path,
115
+ str(request.url.query),
116
+ [list(error.get("loc", ())) for error in exc.errors()],
117
+ [str(error.get("type", "")) for error in exc.errors()],
118
+ message_summary,
119
+ tool_names,
120
+ )
121
+ return await request_validation_exception_handler(request, exc)
122
+
123
+ @app.exception_handler(ProviderError)
124
+ async def provider_error_handler(request: Request, exc: ProviderError):
125
+ """Handle provider-specific errors and return Anthropic format."""
126
+ err_settings = get_settings()
127
+ if err_settings.log_api_error_tracebacks:
128
+ logger.error(
129
+ "Provider Error: error_type={} status_code={} message={}",
130
+ exc.error_type,
131
+ exc.status_code,
132
+ exc.message,
133
+ )
134
+ else:
135
+ logger.error(
136
+ "Provider Error: error_type={} status_code={}",
137
+ exc.error_type,
138
+ exc.status_code,
139
+ )
140
+ return JSONResponse(
141
+ status_code=exc.status_code,
142
+ content=exc.to_anthropic_format(),
143
+ )
144
+
145
+ @app.exception_handler(Exception)
146
+ async def general_error_handler(request: Request, exc: Exception):
147
+ """Handle general errors and return Anthropic format."""
148
+ settings = get_settings()
149
+ if settings.log_api_error_tracebacks:
150
+ logger.error("General Error: {}", exc)
151
+ logger.error(traceback.format_exc())
152
+ else:
153
+ logger.error(
154
+ "General Error: path={} method={} exc_type={}",
155
+ request.url.path,
156
+ request.method,
157
+ type(exc).__name__,
158
+ )
159
+ return JSONResponse(
160
+ status_code=500,
161
+ content={
162
+ "type": "error",
163
+ "error": {
164
+ "type": "api_error",
165
+ "message": "An unexpected error occurred.",
166
+ },
167
+ },
168
+ )
169
+
170
+ return app
171
+
172
+
173
+ def create_asgi_app() -> GracefulLifespanApp:
174
+ """Create the server ASGI app with graceful lifespan failure reporting."""
175
+ return GracefulLifespanApp(create_app(lifespan_enabled=False))
api/command_utils.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Command parsing utilities for API optimizations."""
2
+
3
+ import re
4
+ import shlex
5
+
6
+ _ENV_ASSIGNMENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*=.*$")
7
+
8
+
9
+ def _is_env_assignment(part: str) -> bool:
10
+ """Return True when a token is a shell-style env assignment."""
11
+ return bool(_ENV_ASSIGNMENT_RE.match(part))
12
+
13
+
14
+ def _strip_env_assignments(parts: list[str]) -> list[str]:
15
+ """Return command parts after leading shell-style env assignments."""
16
+ cmd_start = 0
17
+ for i, part in enumerate(parts):
18
+ if _is_env_assignment(part):
19
+ cmd_start = i + 1
20
+ else:
21
+ break
22
+ return parts[cmd_start:]
23
+
24
+
25
+ def extract_command_prefix(command: str) -> str:
26
+ """Extract the command prefix for fast prefix detection.
27
+
28
+ Parses a shell command safely, handling environment variables and
29
+ command injection attempts. Returns the command prefix suitable
30
+ for quick identification.
31
+
32
+ Returns:
33
+ Command prefix (e.g., "git", "git commit", "npm install")
34
+ or "none" if no valid command found
35
+ """
36
+ if "`" in command or "$(" in command:
37
+ return "command_injection_detected"
38
+
39
+ try:
40
+ parts = shlex.split(command, posix=False)
41
+ if not parts:
42
+ return "none"
43
+
44
+ env_prefix = []
45
+ cmd_start = 0
46
+ for i, part in enumerate(parts):
47
+ if _is_env_assignment(part):
48
+ env_prefix.append(part)
49
+ cmd_start = i + 1
50
+ else:
51
+ break
52
+
53
+ if cmd_start >= len(parts):
54
+ return "none"
55
+
56
+ cmd_parts = parts[cmd_start:]
57
+ if not cmd_parts:
58
+ return "none"
59
+
60
+ first_word = cmd_parts[0]
61
+ two_word_commands = {
62
+ "git",
63
+ "npm",
64
+ "docker",
65
+ "kubectl",
66
+ "cargo",
67
+ "go",
68
+ "pip",
69
+ "yarn",
70
+ }
71
+
72
+ if first_word in two_word_commands and len(cmd_parts) > 1:
73
+ second_word = cmd_parts[1]
74
+ if not second_word.startswith("-"):
75
+ return f"{first_word} {second_word}"
76
+ return first_word
77
+ return first_word if not env_prefix else " ".join(env_prefix) + " " + first_word
78
+
79
+ except ValueError:
80
+ parts = command.split()
81
+ if not parts:
82
+ return "none"
83
+ cmd_parts = _strip_env_assignments(parts)
84
+ return cmd_parts[0] if cmd_parts else "none"
85
+
86
+
87
+ def extract_filepaths_from_command(command: str, output: str) -> str:
88
+ """Extract file paths from a command locally without API call.
89
+
90
+ Determines if the command reads file contents and extracts paths accordingly.
91
+ Commands like ls/dir/find just list files, so return empty.
92
+ Commands like cat/head/tail actually read contents, so extract the file path.
93
+
94
+ Returns:
95
+ Filepath extraction result in <filepaths> format
96
+ """
97
+ listing_commands = {
98
+ "ls",
99
+ "dir",
100
+ "find",
101
+ "tree",
102
+ "pwd",
103
+ "cd",
104
+ "mkdir",
105
+ "rmdir",
106
+ "rm",
107
+ }
108
+
109
+ reading_commands = {"cat", "head", "tail", "less", "more", "bat", "type"}
110
+
111
+ try:
112
+ parts = shlex.split(command, posix=False)
113
+ if not parts:
114
+ return "<filepaths>\n</filepaths>"
115
+
116
+ cmd_parts = _strip_env_assignments(parts)
117
+ if not cmd_parts:
118
+ return "<filepaths>\n</filepaths>"
119
+
120
+ base_cmd = cmd_parts[0].split("/")[-1].split("\\")[-1].lower()
121
+
122
+ if base_cmd in listing_commands:
123
+ return "<filepaths>\n</filepaths>"
124
+
125
+ if base_cmd in reading_commands:
126
+ filepaths = []
127
+ for part in cmd_parts[1:]:
128
+ if part.startswith("-"):
129
+ continue
130
+ filepaths.append(part)
131
+
132
+ if filepaths:
133
+ paths_str = "\n".join(filepaths)
134
+ return f"<filepaths>\n{paths_str}\n</filepaths>"
135
+ return "<filepaths>\n</filepaths>"
136
+
137
+ if base_cmd == "grep":
138
+ flags_with_args = {"-e", "-f", "-m", "-A", "-B", "-C"}
139
+ pattern_provided_via_flag = False
140
+ positional = []
141
+
142
+ skip_next = False
143
+ for part in cmd_parts[1:]:
144
+ if skip_next:
145
+ skip_next = False
146
+ continue
147
+ if part.startswith("-"):
148
+ if part in flags_with_args:
149
+ if part in {"-e", "-f"}:
150
+ pattern_provided_via_flag = True
151
+ skip_next = True
152
+ continue
153
+ positional.append(part)
154
+
155
+ filepaths = positional if pattern_provided_via_flag else positional[1:]
156
+ if filepaths:
157
+ paths_str = "\n".join(filepaths)
158
+ return f"<filepaths>\n{paths_str}\n</filepaths>"
159
+ return "<filepaths>\n</filepaths>"
160
+
161
+ return "<filepaths>\n</filepaths>"
162
+
163
+ except ValueError:
164
+ return "<filepaths>\n</filepaths>"
api/dependencies.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dependency injection for FastAPI."""
2
+
3
+ import secrets
4
+
5
+ from fastapi import Depends, HTTPException, Request
6
+ from loguru import logger
7
+ from starlette.applications import Starlette
8
+
9
+ from config.settings import Settings
10
+ from config.settings import get_settings as _get_settings
11
+ from core.anthropic import get_user_facing_error_message
12
+ from providers.base import BaseProvider
13
+ from providers.exceptions import (
14
+ AuthenticationError,
15
+ ServiceUnavailableError,
16
+ UnknownProviderTypeError,
17
+ )
18
+ from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry
19
+
20
+ # Process-level cache: only for :func:`get_provider_for_type` / :func:`get_provider`
21
+ # when there is no ``Request``/``app`` (unit tests, scripts). HTTP handlers must pass
22
+ # ``app`` to :func:`resolve_provider` so the app-scoped registry is used.
23
+ _providers: dict[str, BaseProvider] = {}
24
+
25
+
26
+ def get_settings() -> Settings:
27
+ """Return cached :class:`~config.settings.Settings` (FastAPI-friendly alias)."""
28
+ return _get_settings()
29
+
30
+
31
+ def resolve_provider(
32
+ provider_type: str,
33
+ *,
34
+ app: Starlette | None,
35
+ settings: Settings,
36
+ ) -> BaseProvider:
37
+ """Resolve a provider using the app-scoped registry when ``app`` is set.
38
+
39
+ When ``app`` is not ``None``, the app-owned :attr:`app.state.provider_registry`
40
+ must exist (installed by :class:`~api.runtime.AppRuntime` during startup).
41
+ Callers that construct a bare ``FastAPI`` without lifespan must set
42
+ ``app.state.provider_registry`` explicitly.
43
+
44
+ When ``app`` is ``None`` (no HTTP context), uses the process-level
45
+ :data:`_providers` cache only.
46
+ """
47
+ if app is not None:
48
+ reg = getattr(app.state, "provider_registry", None)
49
+ if reg is None:
50
+ raise ServiceUnavailableError(
51
+ "Provider registry is not configured. Ensure AppRuntime startup ran "
52
+ "or assign app.state.provider_registry for test apps."
53
+ )
54
+ return _resolve_with_registry(reg, provider_type, settings)
55
+ return _resolve_with_registry(ProviderRegistry(_providers), provider_type, settings)
56
+
57
+
58
+ def _resolve_with_registry(
59
+ registry: ProviderRegistry, provider_type: str, settings: Settings
60
+ ) -> BaseProvider:
61
+ should_log_init = not registry.is_cached(provider_type)
62
+ try:
63
+ provider = registry.get(provider_type, settings)
64
+ except AuthenticationError as e:
65
+ # Provider :class:`~providers.exceptions.AuthenticationError` messages are
66
+ # curated configuration hints (env var names, docs links), not upstream noise.
67
+ detail = str(e).strip() or get_user_facing_error_message(e)
68
+ raise HTTPException(status_code=503, detail=detail) from e
69
+ except UnknownProviderTypeError:
70
+ logger.error(
71
+ "Unknown provider_type: '{}'. Supported: {}",
72
+ provider_type,
73
+ ", ".join(f"'{key}'" for key in PROVIDER_DESCRIPTORS),
74
+ )
75
+ raise
76
+ if should_log_init:
77
+ logger.info("Provider initialized: {}", provider_type)
78
+ return provider
79
+
80
+
81
+ def get_provider_for_type(provider_type: str) -> BaseProvider:
82
+ """Get or create a provider in the process-level cache (no ``app``/Request).
83
+
84
+ HTTP route handlers should call :func:`resolve_provider` with the active
85
+ :attr:`request.app` (via :class:`~api.runtime.AppRuntime`) instead of this
86
+ process-wide cache.
87
+ """
88
+ return resolve_provider(provider_type, app=None, settings=get_settings())
89
+
90
+
91
+ def require_api_key(
92
+ request: Request, settings: Settings = Depends(get_settings)
93
+ ) -> None:
94
+ """Require a server API key (Anthropic-style).
95
+
96
+ Checks `x-api-key` header or `Authorization: Bearer ...` against
97
+ `Settings.anthropic_auth_token`. If `ANTHROPIC_AUTH_TOKEN` is empty, this is a no-op.
98
+ """
99
+ anthropic_auth_token = settings.anthropic_auth_token
100
+ if not anthropic_auth_token:
101
+ # No API key configured -> allow
102
+ return
103
+
104
+ header = (
105
+ request.headers.get("x-api-key")
106
+ or request.headers.get("authorization")
107
+ or request.headers.get("anthropic-auth-token")
108
+ )
109
+ if not header:
110
+ raise HTTPException(status_code=401, detail="Missing API key")
111
+
112
+ # Support both raw key in X-API-Key and Bearer token in Authorization
113
+ token = header
114
+ if header.lower().startswith("bearer "):
115
+ token = header.split(" ", 1)[1]
116
+
117
+ # Strip anything after the first colon to handle tokens with appended model names
118
+ if token and ":" in token:
119
+ token = token.split(":", 1)[0]
120
+
121
+ # Constant-time comparison to avoid leaking the configured token via
122
+ # response-time differences on a per-byte mismatch (CWE-208).
123
+ if not secrets.compare_digest(
124
+ token.encode("utf-8"), anthropic_auth_token.encode("utf-8")
125
+ ):
126
+ raise HTTPException(status_code=401, detail="Invalid API key")
127
+
128
+
129
+ def get_provider() -> BaseProvider:
130
+ """Get or create the default provider (``MODEL`` / ``provider_type``).
131
+
132
+ Process-cache helper for scripts, unit tests, and non-FastAPI callers. HTTP
133
+ handlers must use :func:`resolve_provider` with :attr:`request.app` so the
134
+ app-scoped :class:`~providers.registry.ProviderRegistry` is used.
135
+ """
136
+ return get_provider_for_type(get_settings().provider_type)
137
+
138
+
139
+ async def cleanup_provider():
140
+ """Cleanup all provider resources."""
141
+ global _providers
142
+ await ProviderRegistry(_providers).cleanup()
143
+ _providers = {}
144
+ logger.debug("Provider cleanup completed")
api/detection.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Request detection utilities for API optimizations.
2
+
3
+ Detects quota checks, title generation, prefix detection, suggestion mode,
4
+ and filepath extraction requests to enable fast-path responses.
5
+ """
6
+
7
+ from core.anthropic import extract_text_from_content
8
+
9
+ from .models.anthropic import MessagesRequest
10
+
11
+
12
+ def is_quota_check_request(request_data: MessagesRequest) -> bool:
13
+ """Check if this is a quota probe request.
14
+
15
+ Quota checks are typically simple requests with max_tokens=1
16
+ and a single message containing the word "quota".
17
+ """
18
+ if (
19
+ request_data.max_tokens == 1
20
+ and len(request_data.messages) == 1
21
+ and request_data.messages[0].role == "user"
22
+ ):
23
+ text = extract_text_from_content(request_data.messages[0].content)
24
+ if "quota" in text.lower():
25
+ return True
26
+ return False
27
+
28
+
29
+ def is_title_generation_request(request_data: MessagesRequest) -> bool:
30
+ """Check if this is a conversation title generation request.
31
+
32
+ Title generation requests are detected by a system prompt containing
33
+ title extraction instructions, no tools, and a single user message.
34
+
35
+ Matches Claude Code session title prompts (sentence-case title, JSON
36
+ \"title\" field, etc.).
37
+ """
38
+ if not request_data.system or request_data.tools:
39
+ return False
40
+ system_text = extract_text_from_content(request_data.system).lower()
41
+ if "title" not in system_text:
42
+ return False
43
+ return "sentence-case title" in system_text or (
44
+ "return json" in system_text
45
+ and "field" in system_text
46
+ and ("coding session" in system_text or "this session" in system_text)
47
+ )
48
+
49
+
50
+ def is_prefix_detection_request(request_data: MessagesRequest) -> tuple[bool, str]:
51
+ """Check if this is a fast prefix detection request.
52
+
53
+ Prefix detection requests contain a policy_spec block and
54
+ a Command: section for extracting shell command prefixes.
55
+
56
+ Returns:
57
+ Tuple of (is_prefix_request, command_string)
58
+ """
59
+ if len(request_data.messages) != 1 or request_data.messages[0].role != "user":
60
+ return False, ""
61
+
62
+ content = extract_text_from_content(request_data.messages[0].content)
63
+
64
+ if "<policy_spec>" in content and "Command:" in content:
65
+ try:
66
+ cmd_start = content.rfind("Command:") + len("Command:")
67
+ return True, content[cmd_start:].strip()
68
+ except TypeError:
69
+ return False, ""
70
+
71
+ return False, ""
72
+
73
+
74
+ def is_suggestion_mode_request(request_data: MessagesRequest) -> bool:
75
+ """Check if this is a suggestion mode request.
76
+
77
+ Suggestion mode requests contain "[SUGGESTION MODE:" in the user's message,
78
+ used for auto-suggesting what the user might type next.
79
+ """
80
+ for msg in request_data.messages:
81
+ if msg.role == "user":
82
+ text = extract_text_from_content(msg.content)
83
+ if "[SUGGESTION MODE:" in text:
84
+ return True
85
+ return False
86
+
87
+
88
+ def is_filepath_extraction_request(
89
+ request_data: MessagesRequest,
90
+ ) -> tuple[bool, str, str]:
91
+ """Check if this is a filepath extraction request.
92
+
93
+ Filepath extraction requests have a single user message with
94
+ "Command:" and "Output:" sections, asking to extract file paths
95
+ from command output.
96
+
97
+ Returns:
98
+ Tuple of (is_filepath_request, command, output)
99
+ """
100
+ if len(request_data.messages) != 1 or request_data.messages[0].role != "user":
101
+ return False, "", ""
102
+ if request_data.tools:
103
+ return False, "", ""
104
+
105
+ content = extract_text_from_content(request_data.messages[0].content)
106
+
107
+ if "Command:" not in content or "Output:" not in content:
108
+ return False, "", ""
109
+
110
+ # Match if user content OR system block indicates filepath extraction
111
+ user_has_filepaths = (
112
+ "filepaths" in content.lower() or "<filepaths>" in content.lower()
113
+ )
114
+ system_text = (
115
+ extract_text_from_content(request_data.system) if request_data.system else ""
116
+ )
117
+ system_has_extract = (
118
+ "extract any file paths" in system_text.lower()
119
+ or "file paths that this command" in system_text.lower()
120
+ )
121
+ if not user_has_filepaths and not system_has_extract:
122
+ return False, "", ""
123
+
124
+ cmd_start = content.find("Command:") + len("Command:")
125
+ output_marker = content.find("Output:", cmd_start)
126
+ if output_marker == -1:
127
+ return False, "", ""
128
+
129
+ command = content[cmd_start:output_marker].strip()
130
+ output = content[output_marker + len("Output:") :].strip()
131
+
132
+ for marker in ["<", "\n\n"]:
133
+ if marker in output:
134
+ output = output.split(marker)[0].strip()
135
+
136
+ return True, command, output
api/gateway_model_ids.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gateway-safe model id encoding for Claude Code model discovery."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+ GATEWAY_MODEL_ID_PREFIX = "anthropic"
8
+
9
+ # Claude Code currently treats any model id containing ``claude-3-`` as not
10
+ # supporting thinking. This intentionally uses that client-side capability
11
+ # heuristic while keeping the real provider/model ref reversible for routing.
12
+ NO_THINKING_GATEWAY_MODEL_ID_PREFIX = "claude-3-freecc-no-thinking"
13
+
14
+
15
+ @dataclass(frozen=True, slots=True)
16
+ class DecodedGatewayModelId:
17
+ provider_id: str
18
+ provider_model: str
19
+ force_thinking_enabled: bool | None = None
20
+
21
+
22
+ def gateway_model_id(provider_model_ref: str) -> str:
23
+ """Return the normal Claude Code-discoverable id for a provider/model ref."""
24
+ return f"{GATEWAY_MODEL_ID_PREFIX}/{provider_model_ref}"
25
+
26
+
27
+ def no_thinking_gateway_model_id(provider_model_ref: str) -> str:
28
+ """Return a Claude Code-discoverable id that disables client thinking."""
29
+ return f"{NO_THINKING_GATEWAY_MODEL_ID_PREFIX}/{provider_model_ref}"
30
+
31
+
32
+ def decode_gateway_model_id(model_name: str) -> DecodedGatewayModelId | None:
33
+ """Decode a model id advertised by this gateway, if it is one."""
34
+ prefix, separator, remainder = model_name.partition("/")
35
+ if not separator:
36
+ return None
37
+
38
+ force_thinking_enabled: bool | None
39
+ if prefix == GATEWAY_MODEL_ID_PREFIX:
40
+ force_thinking_enabled = None
41
+ elif prefix == NO_THINKING_GATEWAY_MODEL_ID_PREFIX:
42
+ force_thinking_enabled = False
43
+ else:
44
+ return None
45
+
46
+ provider_id, provider_separator, provider_model = remainder.partition("/")
47
+ if not provider_separator or not provider_model:
48
+ return None
49
+
50
+ return DecodedGatewayModelId(
51
+ provider_id=provider_id,
52
+ provider_model=provider_model,
53
+ force_thinking_enabled=force_thinking_enabled,
54
+ )
api/model_router.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model routing for Claude-compatible requests."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+ from loguru import logger
8
+
9
+ from config.provider_ids import SUPPORTED_PROVIDER_IDS
10
+ from config.settings import Settings
11
+
12
+ from .gateway_model_ids import decode_gateway_model_id
13
+ from .models.anthropic import MessagesRequest, TokenCountRequest
14
+ from providers.rate_limit import GlobalRateLimiter
15
+
16
+
17
+ @dataclass(frozen=True, slots=True)
18
+ class ResolvedModel:
19
+ original_model: str
20
+ provider_id: str
21
+ provider_model: str
22
+ provider_model_ref: str
23
+ thinking_enabled: bool
24
+
25
+
26
+ @dataclass(frozen=True, slots=True)
27
+ class RoutedMessagesRequest:
28
+ request: MessagesRequest
29
+ resolved: ResolvedModel
30
+
31
+
32
+ @dataclass(frozen=True, slots=True)
33
+ class RoutedTokenCountRequest:
34
+ request: TokenCountRequest
35
+ resolved: ResolvedModel
36
+
37
+
38
+ class ModelRouter:
39
+ """Resolve incoming Claude model names to configured provider/model pairs."""
40
+
41
+ def __init__(self, settings: Settings):
42
+ self._settings = settings
43
+
44
+ def _is_auto(self, model_name: str) -> bool:
45
+ """Return whether the model name refers to the virtual 'auto' model."""
46
+ name_lower = model_name.lower()
47
+ return name_lower == "auto" or name_lower == "anthropic/auto"
48
+
49
+ def _normalize_candidate_ref(self, raw_ref: str) -> str | None:
50
+ """Normalize auto candidate refs to ``provider/model`` when possible."""
51
+ candidate = (raw_ref or "").strip()
52
+ if not candidate:
53
+ return None
54
+
55
+ provider_id, separator, remainder = candidate.partition("/")
56
+ if separator and provider_id in SUPPORTED_PROVIDER_IDS and remainder:
57
+ return f"{provider_id}/{remainder}"
58
+
59
+ # Treat bare model ids and vendor/model ids as NVIDIA NIM models.
60
+ return f"nvidia_nim/{candidate}"
61
+
62
+ def resolve(self, claude_model_name: str) -> ResolvedModel:
63
+ # Special virtual model 'auto' maps to the configured default MODEL and
64
+ # enables provider-side fallbacks. Resolve it to the configured model
65
+ # while preserving the original requested name.
66
+ if self._is_auto(claude_model_name):
67
+ # If the user configured an explicit AUTO_MODEL_ORDER, try each
68
+ # provider/model pair in order and pick the first provider that is
69
+ # plausibly configured. Fall back to the single configured MODEL.
70
+ order_csv = (self._settings.auto_model_order or "").strip()
71
+ if order_csv:
72
+ for cand in [c.strip() for c in order_csv.split(",") if c.strip()]:
73
+ if "/" not in cand:
74
+ # assume vendor-prefixed entries; skip malformed
75
+ continue
76
+ provider_id = Settings.parse_provider_type(cand)
77
+ provider_model = Settings.parse_model_name(cand)
78
+ if self._settings.provider_is_configured(provider_id):
79
+ thinking_enabled = self._settings.resolve_thinking(claude_model_name)
80
+ return ResolvedModel(
81
+ original_model=claude_model_name,
82
+ provider_id=provider_id,
83
+ provider_model=provider_model,
84
+ provider_model_ref=cand,
85
+ thinking_enabled=thinking_enabled,
86
+ )
87
+ # No explicit order matched or none configured — fall back to default MODEL
88
+ provider_model_ref = self._settings.model
89
+ provider_id = Settings.parse_provider_type(provider_model_ref)
90
+ provider_model = Settings.parse_model_name(provider_model_ref)
91
+ thinking_enabled = self._settings.resolve_thinking(claude_model_name)
92
+ return ResolvedModel(
93
+ original_model=claude_model_name,
94
+ provider_id=provider_id,
95
+ provider_model=provider_model,
96
+ provider_model_ref=provider_model_ref,
97
+ thinking_enabled=thinking_enabled,
98
+ )
99
+
100
+ (
101
+ direct_provider_id,
102
+ direct_provider_model,
103
+ force_thinking_enabled,
104
+ ) = self._direct_provider_model(claude_model_name)
105
+ if direct_provider_id is not None and direct_provider_model is not None:
106
+ thinking_enabled = (
107
+ force_thinking_enabled
108
+ if force_thinking_enabled is not None
109
+ else self._settings.resolve_thinking(direct_provider_model)
110
+ )
111
+ logger.debug(
112
+ "MODEL DIRECT: '{}' -> provider='{}' model='{}' thinking={}",
113
+ claude_model_name,
114
+ direct_provider_id,
115
+ direct_provider_model,
116
+ thinking_enabled,
117
+ )
118
+ return ResolvedModel(
119
+ original_model=claude_model_name,
120
+ provider_id=direct_provider_id,
121
+ provider_model=direct_provider_model,
122
+ provider_model_ref=claude_model_name,
123
+ thinking_enabled=thinking_enabled,
124
+ )
125
+
126
+ provider_model_ref = self._settings.resolve_model(claude_model_name)
127
+ thinking_enabled = self._settings.resolve_thinking(claude_model_name)
128
+ provider_id = Settings.parse_provider_type(provider_model_ref)
129
+ provider_model = Settings.parse_model_name(provider_model_ref)
130
+ if provider_model != claude_model_name:
131
+ logger.debug(
132
+ "MODEL MAPPING: '{}' -> '{}'", claude_model_name, provider_model
133
+ )
134
+ return ResolvedModel(
135
+ original_model=claude_model_name,
136
+ provider_id=provider_id,
137
+ provider_model=provider_model,
138
+ provider_model_ref=provider_model_ref,
139
+ thinking_enabled=thinking_enabled,
140
+ )
141
+
142
+ def resolve_candidates(self, claude_model_name: str) -> list[ResolvedModel]:
143
+ """Resolve a model name to a prioritized list of candidates.
144
+
145
+ Used by the 'auto' routing logic to implement provider-side failover.
146
+ """
147
+ if not self._is_auto(claude_model_name):
148
+ return [self.resolve(claude_model_name)]
149
+
150
+ healthy_candidates: list[ResolvedModel] = []
151
+ blocked_candidates: list[ResolvedModel] = []
152
+ seen: set[str] = set()
153
+
154
+
155
+ def add_candidate(ref: str | None, source: str) -> None:
156
+ normalized_ref = self._normalize_candidate_ref(ref or "")
157
+ if normalized_ref is None or normalized_ref in seen:
158
+ return
159
+ provider_id = Settings.parse_provider_type(normalized_ref)
160
+ provider_model = Settings.parse_model_name(normalized_ref)
161
+ if self._settings.provider_is_configured(provider_id):
162
+ seen.add(normalized_ref)
163
+ resolved = ResolvedModel(
164
+ original_model=claude_model_name,
165
+ provider_id=provider_id,
166
+ provider_model=provider_model,
167
+ provider_model_ref=normalized_ref,
168
+ thinking_enabled=self._settings.resolve_thinking(claude_model_name),
169
+ )
170
+
171
+ limiter = GlobalRateLimiter.get_scoped_instance(provider_id)
172
+ if limiter.is_blocked():
173
+ logger.debug(
174
+ "Routing: candidate '{}' (from {}) is BLOCKED",
175
+ normalized_ref,
176
+ source,
177
+ )
178
+ blocked_candidates.append(resolved)
179
+ else:
180
+ logger.debug(
181
+ "Routing: added candidate '{}' (from {})",
182
+ normalized_ref,
183
+ source,
184
+ )
185
+ healthy_candidates.append(resolved)
186
+ else:
187
+ logger.debug(
188
+ "Routing: candidate '{}' (from {}) is NOT CONFIGURED",
189
+ normalized_ref,
190
+ source,
191
+ )
192
+
193
+ # 1. Preferred order (AUTO_MODEL_ORDER)
194
+ order_csv = (self._settings.auto_model_order or "").strip()
195
+ if order_csv:
196
+ for cand in [c.strip() for c in order_csv.split(",") if c.strip()]:
197
+ add_candidate(cand, "AUTO_MODEL_PRIORITY")
198
+
199
+ # 2. Main MODEL
200
+ add_candidate(self._settings.model, "MODEL")
201
+
202
+ # 3. NVIDIA Fallbacks
203
+ nim_csv = (self._settings.nvidia_nim_fallback_models or "").strip()
204
+ if nim_csv:
205
+ for cand in [c.strip() for c in nim_csv.split(",") if c.strip()]:
206
+ add_candidate(cand, "NVIDIA_NIM_FALLBACK_MODELS")
207
+
208
+ # 4. Model-specific overrides
209
+ add_candidate(self._settings.model_opus, "MODEL_OPUS")
210
+ add_candidate(self._settings.model_sonnet, "MODEL_SONNET")
211
+ add_candidate(self._settings.model_haiku, "MODEL_HAIKU")
212
+
213
+ all_candidates = healthy_candidates + blocked_candidates
214
+ logger.info(
215
+ "Routing: resolved '{}' to {} candidates: {}",
216
+ claude_model_name,
217
+ len(all_candidates),
218
+ ", ".join(c.provider_model_ref for c in all_candidates),
219
+ )
220
+ return all_candidates
221
+
222
+ def _direct_provider_model(
223
+ self, model_name: str
224
+ ) -> tuple[str | None, str | None, bool | None]:
225
+ decoded = decode_gateway_model_id(model_name)
226
+ if decoded is not None:
227
+ if decoded.provider_id not in SUPPORTED_PROVIDER_IDS:
228
+ return None, None, None
229
+ return (
230
+ decoded.provider_id,
231
+ decoded.provider_model,
232
+ decoded.force_thinking_enabled,
233
+ )
234
+
235
+ provider_id, separator, provider_model = model_name.partition("/")
236
+ if not separator:
237
+ return None, None, None
238
+ if provider_id not in SUPPORTED_PROVIDER_IDS:
239
+ return None, None, None
240
+ if not provider_model:
241
+ return None, None, None
242
+ return provider_id, provider_model, None
243
+
244
+ def resolve_messages_request(
245
+ self, request: MessagesRequest
246
+ ) -> RoutedMessagesRequest:
247
+ """Return an internal routed request context."""
248
+ resolved = self.resolve(request.model)
249
+ routed = request.model_copy(deep=True)
250
+ routed.model = resolved.provider_model
251
+ return RoutedMessagesRequest(request=routed, resolved=resolved)
252
+
253
+ def resolve_token_count_request(
254
+ self, request: TokenCountRequest
255
+ ) -> RoutedTokenCountRequest:
256
+ """Return an internal token-count request context."""
257
+ resolved = self.resolve(request.model)
258
+ routed = request.model_copy(
259
+ update={"model": resolved.provider_model}, deep=True
260
+ )
261
+ return RoutedTokenCountRequest(request=routed, resolved=resolved)
api/models/__init__.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """API models exports."""
2
+
3
+ from .anthropic import (
4
+ ContentBlockImage,
5
+ ContentBlockRedactedThinking,
6
+ ContentBlockText,
7
+ ContentBlockThinking,
8
+ ContentBlockToolResult,
9
+ ContentBlockToolUse,
10
+ Message,
11
+ MessagesRequest,
12
+ Role,
13
+ SystemContent,
14
+ ThinkingConfig,
15
+ TokenCountRequest,
16
+ Tool,
17
+ )
18
+ from .responses import (
19
+ MessagesResponse,
20
+ ModelResponse,
21
+ ModelsListResponse,
22
+ TokenCountResponse,
23
+ Usage,
24
+ )
25
+
26
+ __all__ = [
27
+ "ContentBlockImage",
28
+ "ContentBlockRedactedThinking",
29
+ "ContentBlockText",
30
+ "ContentBlockThinking",
31
+ "ContentBlockToolResult",
32
+ "ContentBlockToolUse",
33
+ "Message",
34
+ "MessagesRequest",
35
+ "MessagesResponse",
36
+ "ModelResponse",
37
+ "ModelsListResponse",
38
+ "Role",
39
+ "SystemContent",
40
+ "ThinkingConfig",
41
+ "TokenCountRequest",
42
+ "TokenCountResponse",
43
+ "Tool",
44
+ "Usage",
45
+ ]
api/models/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (849 Bytes). View file
 
api/models/__pycache__/anthropic.cpython-314.pyc ADDED
Binary file (11.6 kB). View file
 
api/models/__pycache__/responses.cpython-314.pyc ADDED
Binary file (3.69 kB). View file
 
api/models/anthropic.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic models for Anthropic-compatible requests."""
2
+
3
+ from enum import StrEnum
4
+ from typing import Any, Literal
5
+
6
+ from pydantic import BaseModel, ConfigDict, Field
7
+
8
+
9
+ # =============================================================================
10
+ # Content Block Types
11
+ # =============================================================================
12
+ class Role(StrEnum):
13
+ user = "user"
14
+ assistant = "assistant"
15
+ system = "system"
16
+
17
+
18
+ class _AnthropicBlockBase(BaseModel):
19
+ """Pass through provider fields (e.g. ``cache_control``) for native transports."""
20
+
21
+ model_config = ConfigDict(extra="allow")
22
+
23
+
24
+ class ContentBlockText(_AnthropicBlockBase):
25
+ type: Literal["text"]
26
+ text: str
27
+
28
+
29
+ class ContentBlockImage(_AnthropicBlockBase):
30
+ type: Literal["image"]
31
+ source: dict[str, Any]
32
+
33
+
34
+ class ContentBlockToolUse(_AnthropicBlockBase):
35
+ type: Literal["tool_use"]
36
+ id: str
37
+ name: str
38
+ input: dict[str, Any]
39
+
40
+
41
+ class ContentBlockToolResult(_AnthropicBlockBase):
42
+ type: Literal["tool_result"]
43
+ tool_use_id: str
44
+ content: str | list[Any] | dict[str, Any]
45
+
46
+
47
+ class ContentBlockThinking(_AnthropicBlockBase):
48
+ type: Literal["thinking"]
49
+ thinking: str
50
+ signature: str | None = None
51
+
52
+
53
+ class ContentBlockRedactedThinking(_AnthropicBlockBase):
54
+ type: Literal["redacted_thinking"]
55
+ data: str
56
+
57
+
58
+ class ContentBlockServerToolUse(_AnthropicBlockBase):
59
+ """Anthropic server-side tool invocation (e.g. ``web_search``, ``web_fetch``)."""
60
+
61
+ type: Literal["server_tool_use"]
62
+ id: str
63
+ name: str
64
+ input: dict[str, Any]
65
+
66
+
67
+ class ContentBlockWebSearchToolResult(_AnthropicBlockBase):
68
+ type: Literal["web_search_tool_result"]
69
+ tool_use_id: str
70
+ content: Any
71
+
72
+
73
+ class ContentBlockWebFetchToolResult(_AnthropicBlockBase):
74
+ type: Literal["web_fetch_tool_result"]
75
+ tool_use_id: str
76
+ content: Any
77
+
78
+
79
+ class SystemContent(_AnthropicBlockBase):
80
+ type: Literal["text"]
81
+ text: str
82
+
83
+
84
+ # =============================================================================
85
+ # Message Types
86
+ # =============================================================================
87
+ class Message(BaseModel):
88
+ role: Literal["user", "assistant"]
89
+ content: (
90
+ str
91
+ | list[
92
+ ContentBlockText
93
+ | ContentBlockImage
94
+ | ContentBlockToolUse
95
+ | ContentBlockToolResult
96
+ | ContentBlockThinking
97
+ | ContentBlockRedactedThinking
98
+ | ContentBlockServerToolUse
99
+ | ContentBlockWebSearchToolResult
100
+ | ContentBlockWebFetchToolResult
101
+ ]
102
+ )
103
+ reasoning_content: str | None = None
104
+
105
+
106
+ class Tool(_AnthropicBlockBase):
107
+ name: str
108
+ # Anthropic server tools (e.g. web_search beta tools) include a ``type`` and
109
+ # may omit ``input_schema`` because the provider owns the schema.
110
+ type: str | None = None
111
+ description: str | None = None
112
+ input_schema: dict[str, Any] | None = None
113
+
114
+
115
+ class ThinkingConfig(BaseModel):
116
+ enabled: bool | None = True
117
+ type: str | None = None
118
+ budget_tokens: int | None = None
119
+
120
+
121
+ # =============================================================================
122
+ # Request Models
123
+ # =============================================================================
124
+ class MessagesRequest(BaseModel):
125
+ model_config = ConfigDict(extra="allow")
126
+
127
+ model: str
128
+ # Internal routing / debug: accepted on parse but not serialized to providers.
129
+ original_model: str | None = Field(default=None, exclude=True)
130
+ resolved_provider_model: str | None = Field(default=None, exclude=True)
131
+ max_tokens: int | None = None
132
+ messages: list[Message]
133
+ system: str | list[SystemContent] | None = None
134
+ stop_sequences: list[str] | None = None
135
+ stream: bool | None = True
136
+ temperature: float | None = None
137
+ top_p: float | None = None
138
+ top_k: int | None = None
139
+ metadata: dict[str, Any] | None = None
140
+ tools: list[Tool] | None = None
141
+ tool_choice: dict[str, Any] | None = None
142
+ thinking: ThinkingConfig | None = None
143
+ # Native Anthropic / SDK client hints: ignored (not forwarded) for OpenAI Chat conversion.
144
+ context_management: dict[str, Any] | None = None
145
+ output_config: dict[str, Any] | None = None
146
+ mcp_servers: list[dict[str, Any]] | None = None
147
+ extra_body: dict[str, Any] | None = None
148
+
149
+
150
+ class TokenCountRequest(BaseModel):
151
+ model_config = ConfigDict(extra="allow")
152
+
153
+ model: str
154
+ original_model: str | None = Field(default=None, exclude=True)
155
+ resolved_provider_model: str | None = Field(default=None, exclude=True)
156
+ messages: list[Message]
157
+ system: str | list[SystemContent] | None = None
158
+ tools: list[Tool] | None = None
159
+ thinking: ThinkingConfig | None = None
160
+ tool_choice: dict[str, Any] | None = None
161
+ context_management: dict[str, Any] | None = None
162
+ output_config: dict[str, Any] | None = None
163
+ mcp_servers: list[dict[str, Any]] | None = None
api/models/responses.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic models for API responses."""
2
+
3
+ from typing import Any, Literal
4
+
5
+ from pydantic import BaseModel
6
+
7
+ from .anthropic import (
8
+ ContentBlockRedactedThinking,
9
+ ContentBlockText,
10
+ ContentBlockThinking,
11
+ ContentBlockToolUse,
12
+ )
13
+
14
+
15
+ class TokenCountResponse(BaseModel):
16
+ input_tokens: int
17
+
18
+
19
+ class ModelResponse(BaseModel):
20
+ created_at: str
21
+ display_name: str
22
+ id: str
23
+ type: Literal["model"] = "model"
24
+
25
+
26
+ class ModelsListResponse(BaseModel):
27
+ data: list[ModelResponse]
28
+ first_id: str | None
29
+ has_more: bool
30
+ last_id: str | None
31
+
32
+
33
+ class Usage(BaseModel):
34
+ input_tokens: int
35
+ output_tokens: int
36
+ cache_creation_input_tokens: int = 0
37
+ cache_read_input_tokens: int = 0
38
+
39
+
40
+ class MessagesResponse(BaseModel):
41
+ id: str
42
+ model: str
43
+ role: Literal["assistant"] = "assistant"
44
+ content: list[
45
+ ContentBlockText
46
+ | ContentBlockToolUse
47
+ | ContentBlockThinking
48
+ | ContentBlockRedactedThinking
49
+ | dict[str, Any]
50
+ ]
51
+ type: Literal["message"] = "message"
52
+ stop_reason: (
53
+ Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"] | None
54
+ ) = None
55
+ stop_sequence: str | None = None
56
+ usage: Usage
api/optimization_handlers.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Optimization handlers for fast-path API responses.
2
+
3
+ Each handler returns a MessagesResponse if the request matches and the
4
+ optimization is enabled, otherwise None.
5
+ """
6
+
7
+ import uuid
8
+
9
+ from loguru import logger
10
+
11
+ from config.settings import Settings
12
+
13
+ from .command_utils import extract_command_prefix, extract_filepaths_from_command
14
+ from .detection import (
15
+ is_filepath_extraction_request,
16
+ is_prefix_detection_request,
17
+ is_quota_check_request,
18
+ is_suggestion_mode_request,
19
+ is_title_generation_request,
20
+ )
21
+ from .models.anthropic import MessagesRequest
22
+ from .models.responses import MessagesResponse, Usage
23
+
24
+
25
+ def _text_response(
26
+ request_data: MessagesRequest,
27
+ text: str,
28
+ *,
29
+ input_tokens: int,
30
+ output_tokens: int,
31
+ ) -> MessagesResponse:
32
+ return MessagesResponse(
33
+ id=f"msg_{uuid.uuid4()}",
34
+ model=request_data.model,
35
+ content=[{"type": "text", "text": text}],
36
+ stop_reason="end_turn",
37
+ usage=Usage(input_tokens=input_tokens, output_tokens=output_tokens),
38
+ )
39
+
40
+
41
+ def try_prefix_detection(
42
+ request_data: MessagesRequest, settings: Settings
43
+ ) -> MessagesResponse | None:
44
+ """Fast prefix detection - return command prefix without API call."""
45
+ if not settings.fast_prefix_detection:
46
+ return None
47
+
48
+ is_prefix_req, command = is_prefix_detection_request(request_data)
49
+ if not is_prefix_req:
50
+ return None
51
+
52
+ logger.info("Optimization: Fast prefix detection request")
53
+ return _text_response(
54
+ request_data,
55
+ extract_command_prefix(command),
56
+ input_tokens=100,
57
+ output_tokens=5,
58
+ )
59
+
60
+
61
+ def try_quota_mock(
62
+ request_data: MessagesRequest, settings: Settings
63
+ ) -> MessagesResponse | None:
64
+ """Mock quota probe requests."""
65
+ if not settings.enable_network_probe_mock:
66
+ return None
67
+ if not is_quota_check_request(request_data):
68
+ return None
69
+
70
+ logger.info("Optimization: Intercepted and mocked quota probe")
71
+ return _text_response(
72
+ request_data,
73
+ "Quota check passed.",
74
+ input_tokens=10,
75
+ output_tokens=5,
76
+ )
77
+
78
+
79
+ def try_title_skip(
80
+ request_data: MessagesRequest, settings: Settings
81
+ ) -> MessagesResponse | None:
82
+ """Skip title generation requests."""
83
+ if not settings.enable_title_generation_skip:
84
+ return None
85
+ if not is_title_generation_request(request_data):
86
+ return None
87
+
88
+ logger.info("Optimization: Skipped title generation request")
89
+ return _text_response(
90
+ request_data,
91
+ "Conversation",
92
+ input_tokens=100,
93
+ output_tokens=5,
94
+ )
95
+
96
+
97
+ def try_suggestion_skip(
98
+ request_data: MessagesRequest, settings: Settings
99
+ ) -> MessagesResponse | None:
100
+ """Skip suggestion mode requests."""
101
+ if not settings.enable_suggestion_mode_skip:
102
+ return None
103
+ if not is_suggestion_mode_request(request_data):
104
+ return None
105
+
106
+ logger.info("Optimization: Skipped suggestion mode request")
107
+ return _text_response(
108
+ request_data,
109
+ "",
110
+ input_tokens=100,
111
+ output_tokens=1,
112
+ )
113
+
114
+
115
+ def try_filepath_mock(
116
+ request_data: MessagesRequest, settings: Settings
117
+ ) -> MessagesResponse | None:
118
+ """Mock filepath extraction requests."""
119
+ if not settings.enable_filepath_extraction_mock:
120
+ return None
121
+
122
+ is_fp, cmd, output = is_filepath_extraction_request(request_data)
123
+ if not is_fp:
124
+ return None
125
+
126
+ filepaths = extract_filepaths_from_command(cmd, output)
127
+ logger.info("Optimization: Mocked filepath extraction")
128
+ return _text_response(
129
+ request_data,
130
+ filepaths,
131
+ input_tokens=100,
132
+ output_tokens=10,
133
+ )
134
+
135
+
136
+ # Cheapest/most common optimizations first for faster short-circuit.
137
+ OPTIMIZATION_HANDLERS = [
138
+ try_quota_mock,
139
+ try_prefix_detection,
140
+ try_title_skip,
141
+ try_suggestion_skip,
142
+ try_filepath_mock,
143
+ ]
144
+
145
+
146
+ def try_optimizations(
147
+ request_data: MessagesRequest, settings: Settings
148
+ ) -> MessagesResponse | None:
149
+ """Run optimization handlers in order. Returns first match or None."""
150
+ for handler in OPTIMIZATION_HANDLERS:
151
+ result = handler(request_data, settings)
152
+ if result is not None:
153
+ return result
154
+ return None
api/routes.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI route handlers."""
2
+
3
+ from fastapi import APIRouter, Depends, HTTPException, Request, Response
4
+ from loguru import logger
5
+
6
+ from config.settings import Settings
7
+ from core.anthropic import get_token_count
8
+ from providers.registry import ProviderRegistry
9
+
10
+ from . import dependencies
11
+ from .dependencies import get_settings, require_api_key
12
+ from .gateway_model_ids import gateway_model_id, no_thinking_gateway_model_id
13
+ from .models.anthropic import MessagesRequest, TokenCountRequest
14
+ from .models.responses import ModelResponse, ModelsListResponse
15
+ from .services import ClaudeProxyService
16
+ from providers.nvidia_nim import metrics as nvidia_nim_metrics
17
+
18
+ router = APIRouter()
19
+
20
+ DISCOVERED_MODEL_CREATED_AT = "1970-01-01T00:00:00Z"
21
+
22
+
23
+ # The proxy advertises a curated set of provider-backed models. Replace
24
+ # the previous hardcoded Claude model list with the requested NVIDIA-
25
+ # compatible models so clients only see those options.
26
+ REQUESTED_PROVIDER_MODELS = [
27
+ "nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct",
28
+ "nvidia_nim/mistralai/mistral-large-3-675b-instruct-2512",
29
+ "nvidia_nim/abacusai/dracarys-llama-3.1-70b-instruct",
30
+ "nvidia_nim/z-ai/glm4.7",
31
+ "nvidia_nim/stepfun-ai/step-3.5-flash",
32
+ "nvidia_nim/bytedance/seed-oss-36b-instruct",
33
+ "nvidia_nim/mistralai/mistral-nemotron",
34
+ "groq/openai/gpt-oss-120b",
35
+ "groq/openai/gpt-oss-20b",
36
+ "groq/llama-3.3-70b-versatile",
37
+ "groq/meta-llama/llama-4-scout-17b-16e-instruct",
38
+ "groq/qwen/qwen3-32b",
39
+ "cerebras/gpt-oss-120b",
40
+ "cerebras/qwen-3-235b-a22b-instruct-2507",
41
+ "cerebras/zai-glm-4.7",
42
+ "cerebras/llama3.1-8b",
43
+ ]
44
+
45
+
46
+ def get_proxy_service(
47
+ request: Request,
48
+ settings: Settings = Depends(get_settings),
49
+ ) -> ClaudeProxyService:
50
+ """Build the request service for route handlers."""
51
+ return ClaudeProxyService(
52
+ settings,
53
+ provider_getter=lambda provider_type: dependencies.resolve_provider(
54
+ provider_type, app=request.app, settings=settings
55
+ ),
56
+ token_counter=get_token_count,
57
+ )
58
+
59
+
60
+ def _probe_response(allow: str) -> Response:
61
+ """Return an empty success response for compatibility probes."""
62
+ return Response(status_code=204, headers={"Allow": allow})
63
+
64
+
65
+ def _discovered_model_response(model_id: str, *, display_name: str) -> ModelResponse:
66
+ return ModelResponse(
67
+ id=model_id,
68
+ display_name=display_name,
69
+ created_at=DISCOVERED_MODEL_CREATED_AT,
70
+ )
71
+
72
+
73
+ def _append_unique_model(
74
+ models: list[ModelResponse], seen: set[str], model: ModelResponse
75
+ ) -> None:
76
+ if model.id in seen:
77
+ return
78
+ seen.add(model.id)
79
+ models.append(model)
80
+
81
+
82
+ def _append_provider_model_variants(
83
+ models: list[ModelResponse],
84
+ seen: set[str],
85
+ provider_model_ref: str,
86
+ *,
87
+ supports_thinking: bool | None = None,
88
+ ) -> None:
89
+ if supports_thinking is not False:
90
+ _append_unique_model(
91
+ models,
92
+ seen,
93
+ _discovered_model_response(
94
+ gateway_model_id(provider_model_ref),
95
+ display_name=provider_model_ref,
96
+ ),
97
+ )
98
+ _append_unique_model(
99
+ models,
100
+ seen,
101
+ _discovered_model_response(
102
+ no_thinking_gateway_model_id(provider_model_ref),
103
+ display_name=f"{provider_model_ref} (no thinking)",
104
+ ),
105
+ )
106
+
107
+
108
+ def _build_models_list_response(
109
+ settings: Settings, provider_registry: ProviderRegistry | None
110
+ ) -> ModelsListResponse:
111
+ models: list[ModelResponse] = []
112
+ seen: set[str] = set()
113
+
114
+ # Advertise only the requested provider models (no Claude models, no registry auto-discovery).
115
+ # Each ref is added with both thinking and no-thinking variants.
116
+ for provider_ref in REQUESTED_PROVIDER_MODELS:
117
+ # If the ref already contains a provider prefix, use it as-is;
118
+ # otherwise assume it belongs to the NVIDIA NIM provider.
119
+ ref = provider_ref if "/" in provider_ref else f"nvidia_nim/{provider_ref}"
120
+ supports_thinking = None
121
+ if provider_registry is not None:
122
+ # model_id for registry lookups should be provider-prefixed
123
+ provider, model_id = ref.split("/", 1) if "/" in ref else ("nvidia_nim", ref)
124
+ supports_thinking = provider_registry.cached_model_supports_thinking(provider, model_id)
125
+ _append_provider_model_variants(models, seen, ref, supports_thinking=supports_thinking)
126
+
127
+ # Add a virtual `auto` model that maps to the configured MODEL and enables
128
+ # automatic fallback behavior when used by clients.
129
+ _append_unique_model(
130
+ models,
131
+ seen,
132
+ ModelResponse(
133
+ id=gateway_model_id("auto"),
134
+ display_name="auto (use configured fallbacks)",
135
+ created_at=DISCOVERED_MODEL_CREATED_AT,
136
+ ),
137
+ )
138
+
139
+ # Filter out any residual Claude-branded models so the proxy advertises
140
+ # only the provider-backed models requested by the user.
141
+ filtered = [
142
+ m
143
+ for m in models
144
+ if "claude" not in (m.id or "").lower() and "claude" not in (m.display_name or "").lower()
145
+ ]
146
+ # Ensure `auto` model remains available even if filtering removed others.
147
+ if not any(m.id == gateway_model_id("auto") for m in filtered):
148
+ filtered.append(
149
+ ModelResponse(
150
+ id=gateway_model_id("auto"),
151
+ display_name="auto (use configured fallbacks)",
152
+ created_at=DISCOVERED_MODEL_CREATED_AT,
153
+ )
154
+ )
155
+
156
+ return ModelsListResponse(
157
+ data=filtered,
158
+ first_id=filtered[0].id if filtered else None,
159
+ has_more=False,
160
+ last_id=filtered[-1].id if filtered else None,
161
+ )
162
+
163
+
164
+
165
+ # =============================================================================
166
+ # Routes
167
+ # =============================================================================
168
+ @router.post("/v1/messages")
169
+ async def create_message(
170
+ request_data: MessagesRequest,
171
+ service: ClaudeProxyService = Depends(get_proxy_service),
172
+ _auth=Depends(require_api_key),
173
+ ):
174
+ """Create a message (always streaming)."""
175
+ return service.create_message(request_data)
176
+
177
+
178
+ @router.api_route("/v1/messages", methods=["HEAD", "OPTIONS"])
179
+ async def probe_messages(_auth=Depends(require_api_key)):
180
+ """Respond to Claude compatibility probes for the messages endpoint."""
181
+ return _probe_response("POST, HEAD, OPTIONS")
182
+
183
+
184
+ @router.post("/v1/messages/count_tokens")
185
+ async def count_tokens(
186
+ request_data: TokenCountRequest,
187
+ service: ClaudeProxyService = Depends(get_proxy_service),
188
+ _auth=Depends(require_api_key),
189
+ ):
190
+ """Count tokens for a request."""
191
+ return service.count_tokens(request_data)
192
+
193
+
194
+ @router.api_route("/v1/messages/count_tokens", methods=["HEAD", "OPTIONS"])
195
+ async def probe_count_tokens(_auth=Depends(require_api_key)):
196
+ """Respond to Claude compatibility probes for the token count endpoint."""
197
+ return _probe_response("POST, HEAD, OPTIONS")
198
+
199
+
200
+ @router.get("/")
201
+ async def root(
202
+ settings: Settings = Depends(get_settings), _auth=Depends(require_api_key)
203
+ ):
204
+ """Root endpoint."""
205
+ return {
206
+ "status": "ok",
207
+ "provider": settings.provider_type,
208
+ "model": settings.model,
209
+ }
210
+
211
+
212
+ @router.api_route("/", methods=["HEAD", "OPTIONS"])
213
+ async def probe_root(_auth=Depends(require_api_key)):
214
+ """Respond to compatibility probes for the root endpoint."""
215
+ return _probe_response("GET, HEAD, OPTIONS")
216
+
217
+
218
+ @router.get("/health")
219
+ async def health():
220
+ """Health check endpoint."""
221
+ return {"status": "healthy"}
222
+
223
+
224
+ @router.api_route("/health", methods=["HEAD", "OPTIONS"])
225
+ async def probe_health():
226
+ """Respond to compatibility probes for the health endpoint."""
227
+ return _probe_response("GET, HEAD, OPTIONS")
228
+
229
+
230
+ @router.get("/v1/models", response_model=ModelsListResponse)
231
+ async def list_models(
232
+ request: Request,
233
+ settings: Settings = Depends(get_settings),
234
+ _auth=Depends(require_api_key),
235
+ ):
236
+ """List the model ids this proxy advertises to Claude-compatible clients."""
237
+ registry = getattr(request.app.state, "provider_registry", None)
238
+ provider_registry = registry if isinstance(registry, ProviderRegistry) else None
239
+ return _build_models_list_response(settings, provider_registry)
240
+
241
+
242
+ @router.post("/stop")
243
+ async def stop_cli(request: Request, _auth=Depends(require_api_key)):
244
+ """Stop all CLI sessions and pending tasks."""
245
+ handler = getattr(request.app.state, "message_handler", None)
246
+ if not handler:
247
+ # Fallback if messaging not initialized
248
+ cli_manager = getattr(request.app.state, "cli_manager", None)
249
+ if cli_manager:
250
+ await cli_manager.stop_all()
251
+ logger.info("STOP_CLI: source=cli_manager cancelled_count=N/A")
252
+ return {"status": "stopped", "source": "cli_manager"}
253
+ raise HTTPException(status_code=503, detail="Messaging system not initialized")
254
+
255
+ count = await handler.stop_all_tasks()
256
+ logger.info("STOP_CLI: source=handler cancelled_count={}", count)
257
+ return {"status": "stopped", "cancelled_count": count}
258
+
259
+
260
+ @router.get("/admin/fallbacks")
261
+ async def admin_fallbacks(_auth=Depends(require_api_key)):
262
+ """Admin endpoint exposing NVIDIA NIM fallback metrics.
263
+
264
+ Protected by the same API key as other endpoints.
265
+ """
266
+ try:
267
+ data = nvidia_nim_metrics.snapshot()
268
+ except Exception as e:
269
+ logger.warning("ADMIN_FALLBACKS: failed to read metrics: {}", e)
270
+ raise HTTPException(status_code=500, detail="failed to read metrics")
271
+ return {"provider": "nvidia_nim", "fallbacks": data}
api/runtime.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Application runtime composition and lifecycle ownership."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import os
7
+ from dataclasses import dataclass, field
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ from fastapi import FastAPI
11
+ from loguru import logger
12
+
13
+ from config.settings import Settings, get_settings
14
+ from providers.exceptions import ServiceUnavailableError
15
+ from providers.registry import ProviderRegistry
16
+
17
+ if TYPE_CHECKING:
18
+ from cli.manager import CLISessionManager
19
+ from messaging.handler import ClaudeMessageHandler
20
+ from messaging.platforms.base import MessagingPlatform
21
+ from messaging.session import SessionStore
22
+
23
+ _SHUTDOWN_TIMEOUT_S = 5.0
24
+
25
+
26
+ async def best_effort(
27
+ name: str,
28
+ awaitable: Any,
29
+ timeout_s: float = _SHUTDOWN_TIMEOUT_S,
30
+ *,
31
+ log_verbose_errors: bool = False,
32
+ ) -> None:
33
+ """Run a shutdown step with timeout; never raise to callers."""
34
+ try:
35
+ await asyncio.wait_for(awaitable, timeout=timeout_s)
36
+ except TimeoutError:
37
+ logger.warning("Shutdown step timed out: {} ({}s)", name, timeout_s)
38
+ except Exception as e:
39
+ if log_verbose_errors:
40
+ logger.warning(
41
+ "Shutdown step failed: {}: {}: {}",
42
+ name,
43
+ type(e).__name__,
44
+ e,
45
+ )
46
+ else:
47
+ logger.warning(
48
+ "Shutdown step failed: {}: exc_type={}",
49
+ name,
50
+ type(e).__name__,
51
+ )
52
+
53
+
54
+ def warn_if_process_auth_token(settings: Settings) -> None:
55
+ """Warn when server auth was implicitly inherited from the shell."""
56
+ if settings.uses_process_anthropic_auth_token():
57
+ logger.warning(
58
+ "ANTHROPIC_AUTH_TOKEN is set in the process environment but not in "
59
+ "a configured .env file. The proxy will require that token. Add "
60
+ "ANTHROPIC_AUTH_TOKEN= to .env to disable proxy auth, or set the "
61
+ "same token in .env to make server auth explicit."
62
+ )
63
+
64
+
65
+ def log_startup_failure(settings: Settings, exc: Exception) -> None:
66
+ """Log startup failures without traceback noise unless verbose diagnostics are enabled."""
67
+ message = startup_failure_message(settings, exc)
68
+ logger.error("Startup failed:\n{}", message)
69
+
70
+
71
+ def startup_failure_message(settings: Settings, exc: Exception) -> str:
72
+ """Return a concise startup failure message for logs and ASGI lifespan failure."""
73
+ if isinstance(exc, ServiceUnavailableError):
74
+ return exc.message.strip() or "Server startup failed."
75
+
76
+ if settings.log_api_error_tracebacks:
77
+ return f"{type(exc).__name__}: {exc}"
78
+
79
+ return f"Server startup failed: exc_type={type(exc).__name__}"
80
+
81
+
82
+ def _should_continue_after_model_validation_failure(exc: Exception) -> bool:
83
+ """Return whether a model-validation failure should be downgraded to a warning.
84
+
85
+ Provider discovery can fail transiently or due to local environment issues
86
+ (for example, a missing runtime dependency in the provider's process path).
87
+ We keep startup alive in those cases so the configured proxy can still serve
88
+ requests and advertise the models that are already known from settings.
89
+ """
90
+ if not isinstance(exc, ServiceUnavailableError):
91
+ return False
92
+
93
+ message = (exc.message or str(exc)).lower()
94
+ return "problem=query failure:" in message
95
+
96
+
97
+ @dataclass(slots=True)
98
+ class AppRuntime:
99
+ """Own optional messaging, CLI, session, and provider runtime resources."""
100
+
101
+ app: FastAPI
102
+ settings: Settings
103
+ _provider_registry: ProviderRegistry | None = field(default=None, init=False)
104
+ messaging_platform: MessagingPlatform | None = None
105
+ message_handler: ClaudeMessageHandler | None = None
106
+ cli_manager: CLISessionManager | None = None
107
+
108
+ @classmethod
109
+ def for_app(
110
+ cls,
111
+ app: FastAPI,
112
+ settings: Settings | None = None,
113
+ ) -> AppRuntime:
114
+ return cls(app=app, settings=settings or get_settings())
115
+
116
+ async def startup(self) -> None:
117
+ logger.info("Starting Claude Code Proxy...")
118
+ self._provider_registry = ProviderRegistry()
119
+ self.app.state.provider_registry = self._provider_registry
120
+ try:
121
+ warn_if_process_auth_token(self.settings)
122
+ try:
123
+ # Use a reasonable timeout for startup validation to prevent blocking healthy checks.
124
+ await asyncio.wait_for(
125
+ self._provider_registry.validate_configured_models(self.settings),
126
+ timeout=15.0,
127
+ )
128
+ except Exception as exc:
129
+ logger.warning(
130
+ "Startup model validation skipped or timed out: continuing in lazy mode. "
131
+ "Reason: {}",
132
+ str(exc) or type(exc).__name__,
133
+ )
134
+ self._provider_registry.start_model_list_refresh(self.settings)
135
+ await self._start_messaging_if_configured()
136
+ self._publish_state()
137
+ except Exception as exc:
138
+ log_startup_failure(self.settings, exc)
139
+ await best_effort(
140
+ "provider_registry.cleanup",
141
+ self._provider_registry.cleanup(),
142
+ log_verbose_errors=self.settings.log_api_error_tracebacks,
143
+ )
144
+ raise
145
+
146
+ async def shutdown(self) -> None:
147
+ verbose = self.settings.log_api_error_tracebacks
148
+ if self.message_handler is not None:
149
+ try:
150
+ self.message_handler.session_store.flush_pending_save()
151
+ except Exception as e:
152
+ if verbose:
153
+ logger.warning("Session store flush on shutdown: {}", e)
154
+ else:
155
+ logger.warning(
156
+ "Session store flush on shutdown: exc_type={}",
157
+ type(e).__name__,
158
+ )
159
+
160
+ logger.info("Shutdown requested, cleaning up...")
161
+ if self.messaging_platform:
162
+ await best_effort(
163
+ "messaging_platform.stop",
164
+ self.messaging_platform.stop(),
165
+ log_verbose_errors=verbose,
166
+ )
167
+ if self.cli_manager:
168
+ await best_effort(
169
+ "cli_manager.stop_all",
170
+ self.cli_manager.stop_all(),
171
+ log_verbose_errors=verbose,
172
+ )
173
+ if self._provider_registry is not None:
174
+ await best_effort(
175
+ "provider_registry.cleanup",
176
+ self._provider_registry.cleanup(),
177
+ log_verbose_errors=verbose,
178
+ )
179
+ await self._shutdown_limiter()
180
+ logger.info("Server shut down cleanly")
181
+
182
+ async def _start_messaging_if_configured(self) -> None:
183
+ try:
184
+ from messaging.platforms.factory import (
185
+ MessagingPlatformOptions,
186
+ create_messaging_platform,
187
+ )
188
+
189
+ self.messaging_platform = create_messaging_platform(
190
+ self.settings.messaging_platform,
191
+ MessagingPlatformOptions(
192
+ telegram_bot_token=self.settings.telegram_bot_token,
193
+ allowed_telegram_user_id=self.settings.allowed_telegram_user_id,
194
+ discord_bot_token=self.settings.discord_bot_token,
195
+ allowed_discord_channels=self.settings.allowed_discord_channels,
196
+ voice_note_enabled=self.settings.voice_note_enabled,
197
+ whisper_model=self.settings.whisper_model,
198
+ whisper_device=self.settings.whisper_device,
199
+ hf_token=self.settings.hf_token,
200
+ nvidia_nim_api_key=self.settings.nvidia_nim_api_key_qwen,
201
+ messaging_rate_limit=self.settings.messaging_rate_limit,
202
+ messaging_rate_window=self.settings.messaging_rate_window,
203
+ log_raw_messaging_content=self.settings.log_raw_messaging_content,
204
+ log_api_error_tracebacks=self.settings.log_api_error_tracebacks,
205
+ ),
206
+ )
207
+
208
+ if self.messaging_platform:
209
+ await self._start_message_handler()
210
+
211
+ except ImportError as e:
212
+ if self.settings.log_api_error_tracebacks:
213
+ logger.warning("Messaging module import error: {}", e)
214
+ else:
215
+ logger.warning(
216
+ "Messaging module import error: exc_type={}",
217
+ type(e).__name__,
218
+ )
219
+ except Exception as e:
220
+ if self.settings.log_api_error_tracebacks:
221
+ logger.error("Failed to start messaging platform: {}", e)
222
+ import traceback
223
+
224
+ logger.error(traceback.format_exc())
225
+ else:
226
+ logger.error(
227
+ "Failed to start messaging platform: exc_type={}",
228
+ type(e).__name__,
229
+ )
230
+
231
+ async def _start_message_handler(self) -> None:
232
+ from cli.manager import CLISessionManager
233
+ from messaging.handler import ClaudeMessageHandler
234
+ from messaging.session import SessionStore
235
+
236
+ workspace = (
237
+ os.path.abspath(self.settings.allowed_dir)
238
+ if self.settings.allowed_dir
239
+ else os.getcwd()
240
+ )
241
+ os.makedirs(workspace, exist_ok=True)
242
+
243
+ data_path = os.path.abspath(self.settings.claude_workspace)
244
+ os.makedirs(data_path, exist_ok=True)
245
+
246
+ api_url = f"http://{self.settings.host}:{self.settings.port}/v1"
247
+ allowed_dirs = [workspace] if self.settings.allowed_dir else []
248
+ plans_dir_abs = os.path.abspath(
249
+ os.path.join(self.settings.claude_workspace, "plans")
250
+ )
251
+ plans_directory = os.path.relpath(plans_dir_abs, workspace)
252
+ self.cli_manager = CLISessionManager(
253
+ workspace_path=workspace,
254
+ api_url=api_url,
255
+ allowed_dirs=allowed_dirs,
256
+ plans_directory=plans_directory,
257
+ claude_bin=self.settings.claude_cli_bin,
258
+ log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics,
259
+ log_messaging_error_details=self.settings.log_messaging_error_details,
260
+ )
261
+
262
+ session_store = SessionStore(
263
+ storage_path=os.path.join(data_path, "sessions.json"),
264
+ message_log_cap=self.settings.max_message_log_entries_per_chat,
265
+ )
266
+ platform = self.messaging_platform
267
+ assert platform is not None
268
+ self.message_handler = ClaudeMessageHandler(
269
+ platform=platform,
270
+ cli_manager=self.cli_manager,
271
+ session_store=session_store,
272
+ debug_platform_edits=self.settings.debug_platform_edits,
273
+ debug_subagent_stack=self.settings.debug_subagent_stack,
274
+ log_raw_messaging_content=self.settings.log_raw_messaging_content,
275
+ log_raw_cli_diagnostics=self.settings.log_raw_cli_diagnostics,
276
+ log_messaging_error_details=self.settings.log_messaging_error_details,
277
+ )
278
+ self._restore_tree_state(session_store)
279
+
280
+ platform.on_message(self.message_handler.handle_message)
281
+ await platform.start()
282
+ logger.info(f"{platform.name} platform started with message handler")
283
+
284
+ def _restore_tree_state(self, session_store: SessionStore) -> None:
285
+ saved_trees = session_store.get_all_trees()
286
+ if not saved_trees:
287
+ return
288
+ if self.message_handler is None:
289
+ return
290
+
291
+ logger.info(f"Restoring {len(saved_trees)} conversation trees...")
292
+ from messaging.trees.queue_manager import TreeQueueManager
293
+
294
+ self.message_handler.replace_tree_queue(
295
+ TreeQueueManager.from_dict(
296
+ {
297
+ "trees": saved_trees,
298
+ "node_to_tree": session_store.get_node_mapping(),
299
+ },
300
+ queue_update_callback=self.message_handler.update_queue_positions,
301
+ node_started_callback=self.message_handler.mark_node_processing,
302
+ )
303
+ )
304
+ if self.message_handler.tree_queue.cleanup_stale_nodes() > 0:
305
+ tree_data = self.message_handler.tree_queue.to_dict()
306
+ session_store.sync_from_tree_data(
307
+ tree_data["trees"], tree_data["node_to_tree"]
308
+ )
309
+
310
+ def _publish_state(self) -> None:
311
+ self.app.state.messaging_platform = self.messaging_platform
312
+ self.app.state.message_handler = self.message_handler
313
+ self.app.state.cli_manager = self.cli_manager
314
+
315
+ async def _shutdown_limiter(self) -> None:
316
+ verbose = self.settings.log_api_error_tracebacks
317
+ try:
318
+ from messaging.limiter import MessagingRateLimiter
319
+ except Exception as e:
320
+ if verbose:
321
+ logger.debug(
322
+ "Rate limiter shutdown skipped (import failed): {}: {}",
323
+ type(e).__name__,
324
+ e,
325
+ )
326
+ else:
327
+ logger.debug(
328
+ "Rate limiter shutdown skipped (import failed): exc_type={}",
329
+ type(e).__name__,
330
+ )
331
+ return
332
+
333
+ await best_effort(
334
+ "MessagingRateLimiter.shutdown_instance",
335
+ MessagingRateLimiter.shutdown_instance(),
336
+ timeout_s=2.0,
337
+ log_verbose_errors=verbose,
338
+ )
api/services.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Application services for the Claude-compatible API."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import traceback
6
+ import uuid
7
+ from collections.abc import AsyncIterator, Callable
8
+ from typing import Any
9
+
10
+ from fastapi import HTTPException
11
+ from fastapi.responses import StreamingResponse
12
+ from loguru import logger
13
+
14
+ from config.settings import Settings
15
+ from core.anthropic import get_token_count, get_user_facing_error_message
16
+ from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS, format_sse_event
17
+ from providers.base import BaseProvider
18
+ from providers.exceptions import (
19
+ InvalidRequestError,
20
+ OverloadedError,
21
+ ProviderError,
22
+ RateLimitError,
23
+ )
24
+
25
+ from .model_router import ModelRouter
26
+ from .models.anthropic import MessagesRequest, TokenCountRequest
27
+ from .models.responses import TokenCountResponse
28
+ from .optimization_handlers import try_optimizations
29
+ from .web_tools.egress import WebFetchEgressPolicy
30
+ from .web_tools.request import (
31
+ is_web_server_tool_request,
32
+ openai_chat_upstream_server_tool_error,
33
+ )
34
+
35
+ TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int]
36
+
37
+ ProviderGetter = Callable[[str], BaseProvider]
38
+
39
+ # Providers that use ``/chat/completions`` + Anthropic-to-OpenAI conversion (not native Messages).
40
+ _OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "groq", "cerebras"})
41
+
42
+
43
+ def anthropic_sse_streaming_response(
44
+ body: AsyncIterator[str],
45
+ ) -> StreamingResponse:
46
+ """Return a :class:`StreamingResponse` for Anthropic-style SSE streams."""
47
+ return StreamingResponse(
48
+ body,
49
+ media_type="text/event-stream",
50
+ headers=ANTHROPIC_SSE_RESPONSE_HEADERS,
51
+ )
52
+
53
+
54
+ def _http_status_for_unexpected_service_exception(_exc: BaseException) -> int:
55
+ """HTTP status for uncaught non-provider failures (stable client contract)."""
56
+ return 500
57
+
58
+
59
+ def _log_unexpected_service_exception(
60
+ settings: Settings,
61
+ exc: BaseException,
62
+ *,
63
+ context: str,
64
+ request_id: str | None = None,
65
+ ) -> None:
66
+ """Log service-layer failures without echoing exception text unless opted in."""
67
+ if settings.log_api_error_tracebacks:
68
+ if request_id is not None:
69
+ logger.error("{} request_id={}: {}", context, request_id, exc)
70
+ else:
71
+ logger.error("{}: {}", context, exc)
72
+ logger.error(traceback.format_exc())
73
+ return
74
+ if request_id is not None:
75
+ logger.error(
76
+ "{} request_id={} exc_type={}",
77
+ context,
78
+ request_id,
79
+ type(exc).__name__,
80
+ )
81
+ else:
82
+ logger.error("{} exc_type={}", context, type(exc).__name__)
83
+
84
+
85
+ def _require_non_empty_messages(messages: list[Any]) -> None:
86
+ if not messages:
87
+ raise InvalidRequestError("messages cannot be empty")
88
+
89
+
90
+ class ClaudeProxyService:
91
+ """Coordinate request optimization, model routing, token count, and providers."""
92
+
93
+ def __init__(
94
+ self,
95
+ settings: Settings,
96
+ provider_getter: ProviderGetter,
97
+ model_router: ModelRouter | None = None,
98
+ token_counter: TokenCounter = get_token_count,
99
+ ):
100
+ self._settings = settings
101
+ self._provider_getter = provider_getter
102
+ self._model_router = model_router or ModelRouter(settings)
103
+ self._token_counter = token_counter
104
+
105
+ def create_message(self, request_data: MessagesRequest) -> object:
106
+ """Create a message response or streaming response with optional failover."""
107
+ from .web_tools.streaming import stream_web_server_tool_response
108
+ try:
109
+ _require_non_empty_messages(request_data.messages)
110
+
111
+ candidates = self._model_router.resolve_candidates(request_data.model)
112
+ if not candidates:
113
+ raise InvalidRequestError(f"No configured models available for '{request_data.model}'")
114
+
115
+ # For 'auto' requests with multiple candidates, we wrap the stream in a failover loop.
116
+ if len(candidates) > 1:
117
+ return anthropic_sse_streaming_response(
118
+ self._stream_with_fallbacks(candidates, request_data)
119
+ )
120
+
121
+ # Standard path for single-model requests
122
+ return self._create_single_message(candidates[0], request_data)
123
+
124
+ except ProviderError:
125
+ raise
126
+ except Exception as e:
127
+ _log_unexpected_service_exception(
128
+ self._settings, e, context="CREATE_MESSAGE_ERROR"
129
+ )
130
+ raise HTTPException(
131
+ status_code=_http_status_for_unexpected_service_exception(e),
132
+ detail=get_user_facing_error_message(e),
133
+ ) from e
134
+
135
+ def _create_single_message(
136
+ self, resolved: ResolvedModel, request_data: MessagesRequest
137
+ ) -> object:
138
+ """Create a single message response from a resolved model."""
139
+ routed_request = request_data.model_copy(deep=True)
140
+ routed_request.model = resolved.provider_model
141
+
142
+ if resolved.provider_id in _OPENAI_CHAT_UPSTREAM_IDS:
143
+ tool_err = openai_chat_upstream_server_tool_error(
144
+ routed_request,
145
+ web_tools_enabled=self._settings.enable_web_server_tools,
146
+ )
147
+ if tool_err is not None:
148
+ raise InvalidRequestError(tool_err)
149
+
150
+ if self._settings.enable_web_server_tools and is_web_server_tool_request(
151
+ routed_request
152
+ ):
153
+ input_tokens = self._token_counter(
154
+ routed_request.messages, routed_request.system, routed_request.tools
155
+ )
156
+ logger.info("Optimization: Handling Anthropic web server tool")
157
+ egress = WebFetchEgressPolicy(
158
+ allow_private_network_targets=self._settings.web_fetch_allow_private_networks,
159
+ allowed_schemes=self._settings.web_fetch_allowed_scheme_set(),
160
+ )
161
+ return anthropic_sse_streaming_response(
162
+ stream_web_server_tool_response(
163
+ routed_request,
164
+ input_tokens=input_tokens,
165
+ web_fetch_egress=egress,
166
+ verbose_client_errors=self._settings.log_api_error_tracebacks,
167
+ ),
168
+ )
169
+
170
+ optimized = try_optimizations(routed_request, self._settings)
171
+ if optimized is not None:
172
+ return optimized
173
+
174
+ provider = self._provider_getter(resolved.provider_id)
175
+ provider.preflight_stream(
176
+ routed_request,
177
+ thinking_enabled=resolved.thinking_enabled,
178
+ )
179
+
180
+ request_id = f"req_{uuid.uuid4().hex[:12]}"
181
+ logger.info(
182
+ "API_REQUEST: request_id={} model={} messages={}",
183
+ request_id,
184
+ routed_request.model,
185
+ len(routed_request.messages),
186
+ )
187
+
188
+ input_tokens = self._token_counter(
189
+ routed_request.messages, routed_request.system, routed_request.tools
190
+ )
191
+ return anthropic_sse_streaming_response(
192
+ provider.stream_response(
193
+ routed_request,
194
+ input_tokens=input_tokens,
195
+ request_id=request_id,
196
+ thinking_enabled=resolved.thinking_enabled,
197
+ ),
198
+ )
199
+
200
+ async def _stream_with_fallbacks(
201
+ self, candidates: list[ResolvedModel], request_data: MessagesRequest
202
+ ) -> AsyncIterator[str]:
203
+ """Iterate through candidates until one succeeds or all fail."""
204
+ last_exc: Exception | None = None
205
+
206
+ for i, resolved in enumerate(candidates):
207
+ try:
208
+ provider = self._provider_getter(resolved.provider_id)
209
+ routed_request = request_data.model_copy(deep=True)
210
+ routed_request.model = resolved.provider_model
211
+
212
+ provider.preflight_stream(
213
+ routed_request,
214
+ thinking_enabled=resolved.thinking_enabled,
215
+ )
216
+
217
+ request_id = f"req_{uuid.uuid4().hex[:12]}"
218
+ logger.info(
219
+ "API_REQUEST (auto fallback {}/{}): request_id={} provider={} model={}",
220
+ i + 1,
221
+ len(candidates),
222
+ request_id,
223
+ resolved.provider_id,
224
+ resolved.provider_model,
225
+ )
226
+
227
+ input_tokens = self._token_counter(
228
+ routed_request.messages, routed_request.system, routed_request.tools
229
+ )
230
+
231
+ # Attempt to stream from this provider.
232
+ async for event in provider.stream_response(
233
+ routed_request,
234
+ input_tokens=input_tokens,
235
+ request_id=request_id,
236
+ thinking_enabled=resolved.thinking_enabled,
237
+ ):
238
+ yield event
239
+ # CRITICAL: If we have yielded even one event, we have committed to this provider.
240
+ # We must not fallback to another candidate mid-stream.
241
+ return # Success, exit the fallback loop.
242
+
243
+ except (RateLimitError, OverloadedError) as e:
244
+ logger.warning(
245
+ "Provider '{}' is rate limited or overloaded ({}). Trying next candidate...",
246
+ resolved.provider_id,
247
+ e.status_code,
248
+ )
249
+ last_exc = e
250
+ continue
251
+ except Exception as e:
252
+ logger.error(
253
+ "Provider '{}' failed with unexpected error: {}. Trying next candidate...",
254
+ resolved.provider_id,
255
+ e,
256
+ )
257
+ last_exc = e
258
+ continue
259
+
260
+ err_msg = str(last_exc) if last_exc else "No candidates succeeded"
261
+ yield format_sse_event(
262
+ "error",
263
+ {
264
+ "type": "error",
265
+ "error": {
266
+ "type": "api_error",
267
+ "message": f"All fallback candidates failed: {err_msg}",
268
+ },
269
+ },
270
+ )
271
+ if last_exc:
272
+ raise last_exc
273
+ raise InvalidRequestError("No candidates succeeded")
274
+
275
+ def count_tokens(self, request_data: TokenCountRequest) -> TokenCountResponse:
276
+ """Count tokens for a request after applying configured model routing."""
277
+ request_id = f"req_{uuid.uuid4().hex[:12]}"
278
+ with logger.contextualize(request_id=request_id):
279
+ try:
280
+ _require_non_empty_messages(request_data.messages)
281
+ routed = self._model_router.resolve_token_count_request(request_data)
282
+ tokens = self._token_counter(
283
+ routed.request.messages, routed.request.system, routed.request.tools
284
+ )
285
+ logger.info(
286
+ "COUNT_TOKENS: request_id={} model={} messages={} input_tokens={}",
287
+ request_id,
288
+ routed.request.model,
289
+ len(routed.request.messages),
290
+ tokens,
291
+ )
292
+ return TokenCountResponse(input_tokens=tokens)
293
+ except ProviderError:
294
+ raise
295
+ except Exception as e:
296
+ _log_unexpected_service_exception(
297
+ self._settings,
298
+ e,
299
+ context="COUNT_TOKENS_ERROR",
300
+ request_id=request_id,
301
+ )
302
+ raise HTTPException(
303
+ status_code=_http_status_for_unexpected_service_exception(e),
304
+ detail=get_user_facing_error_message(e),
305
+ ) from e
api/validation_log.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Safe metadata summaries for HTTP 422 validation logging (no raw text content)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+
8
+ def summarize_request_validation_body(
9
+ body: Any,
10
+ ) -> tuple[list[dict[str, Any]], list[str]]:
11
+ """Return message shape summary and tool name list for debug logs."""
12
+ messages = body.get("messages") if isinstance(body, dict) else None
13
+ message_summary: list[dict[str, Any]] = []
14
+ if isinstance(messages, list):
15
+ for msg in messages:
16
+ if not isinstance(msg, dict):
17
+ message_summary.append({"message_kind": type(msg).__name__})
18
+ continue
19
+ content = msg.get("content")
20
+ item: dict[str, Any] = {
21
+ "role": msg.get("role"),
22
+ "content_kind": type(content).__name__,
23
+ }
24
+ if isinstance(content, list):
25
+ item["block_types"] = [
26
+ block.get("type", "dict")
27
+ if isinstance(block, dict)
28
+ else type(block).__name__
29
+ for block in content[:12]
30
+ ]
31
+ item["block_keys"] = [
32
+ sorted(str(key) for key in block)[:12]
33
+ for block in content[:5]
34
+ if isinstance(block, dict)
35
+ ]
36
+ elif isinstance(content, str):
37
+ item["content_length"] = len(content)
38
+ message_summary.append(item)
39
+
40
+ tool_names: list[str] = []
41
+ if isinstance(body, dict) and isinstance(body.get("tools"), list):
42
+ tool_names = [
43
+ str(tool.get("name", ""))
44
+ for tool in body["tools"]
45
+ if isinstance(tool, dict)
46
+ ]
47
+
48
+ return message_summary, tool_names
api/web_server_tools.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Compatibility re-exports for :mod:`api.web_tools` (web_search / web_fetch)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import httpx
6
+
7
+ from api.web_tools.egress import (
8
+ WebFetchEgressPolicy,
9
+ WebFetchEgressViolation,
10
+ enforce_web_fetch_egress,
11
+ )
12
+ from api.web_tools.request import is_web_server_tool_request
13
+ from api.web_tools.streaming import stream_web_server_tool_response
14
+
15
+ __all__ = [
16
+ "WebFetchEgressPolicy",
17
+ "WebFetchEgressViolation",
18
+ "enforce_web_fetch_egress",
19
+ "httpx",
20
+ "is_web_server_tool_request",
21
+ "stream_web_server_tool_response",
22
+ ]
api/web_tools/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Submodules for Anthropic web server tool handling (search/fetch, egress, streaming)."""
2
+
3
+ from .egress import (
4
+ WebFetchEgressPolicy,
5
+ WebFetchEgressViolation,
6
+ enforce_web_fetch_egress,
7
+ )
8
+ from .request import is_web_server_tool_request
9
+ from .streaming import stream_web_server_tool_response
10
+
11
+ __all__ = [
12
+ "WebFetchEgressPolicy",
13
+ "WebFetchEgressViolation",
14
+ "enforce_web_fetch_egress",
15
+ "is_web_server_tool_request",
16
+ "stream_web_server_tool_response",
17
+ ]
api/web_tools/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (571 Bytes). View file
 
api/web_tools/__pycache__/constants.cpython-314.pyc ADDED
Binary file (680 Bytes). View file
 
api/web_tools/__pycache__/egress.cpython-314.pyc ADDED
Binary file (5.32 kB). View file
 
api/web_tools/__pycache__/parsers.cpython-314.pyc ADDED
Binary file (8.64 kB). View file
 
api/web_tools/__pycache__/request.cpython-314.pyc ADDED
Binary file (6.48 kB). View file
 
api/web_tools/__pycache__/streaming.cpython-314.pyc ADDED
Binary file (6.6 kB). View file
 
api/web_tools/constants.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Limits and defaults for outbound web server tool HTTP."""
2
+
3
+ _REQUEST_TIMEOUT_S = 20.0
4
+ _MAX_SEARCH_RESULTS = 10
5
+ _MAX_FETCH_CHARS = 24_000
6
+ # Hard cap on raw bytes read from HTTP responses before decode / HTML parse (memory bound).
7
+ _MAX_WEB_FETCH_RESPONSE_BYTES = 2 * 1024 * 1024
8
+ # Drain at most this many bytes from redirect responses before following Location.
9
+ _REDIRECT_RESPONSE_BODY_CAP_BYTES = 65_536
10
+ _MAX_WEB_FETCH_REDIRECTS = 10
11
+ _WEB_FETCH_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308})
12
+
13
+ _WEB_TOOL_HTTP_HEADERS = {
14
+ "User-Agent": "Mozilla/5.0 compatible; free-claude-code/2.0",
15
+ }
api/web_tools/egress.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Egress policy for user-controlled web_fetch URLs (SSRF guard)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import ipaddress
6
+ import socket
7
+ from dataclasses import dataclass
8
+ from urllib.parse import urlparse
9
+
10
+
11
+ @dataclass(frozen=True, slots=True)
12
+ class WebFetchEgressPolicy:
13
+ """Egress rules for user-influenced web_fetch URLs."""
14
+
15
+ allow_private_network_targets: bool
16
+ allowed_schemes: frozenset[str]
17
+
18
+
19
+ class WebFetchEgressViolation(ValueError):
20
+ """Raised when a web_fetch URL is rejected by egress policy (SSRF guard)."""
21
+
22
+
23
+ def _port_for_url(parsed) -> int:
24
+ if parsed.port is not None:
25
+ return parsed.port
26
+ return 443 if (parsed.scheme or "").lower() == "https" else 80
27
+
28
+
29
+ def _stream_getaddrinfo_or_raise(host: str, port: int) -> list[tuple]:
30
+ try:
31
+ return socket.getaddrinfo(
32
+ host, port, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP
33
+ )
34
+ except OSError as exc:
35
+ raise WebFetchEgressViolation(
36
+ f"Could not resolve host {host!r}: {exc}"
37
+ ) from exc
38
+
39
+
40
+ def get_validated_stream_addrinfos_for_egress(
41
+ url: str, policy: WebFetchEgressPolicy
42
+ ) -> list[tuple]:
43
+ """Resolve and validate a URL for web_fetch, returning getaddrinfo rows for pinning.
44
+
45
+ Each HTTP connect pins to only these `getaddrinfo` results so a malicious DNS
46
+ server cannot rebind to a disallowed address between resolution and the TCP
47
+ connect (used by :func:`api.web_tools.outbound._run_web_fetch`).
48
+ """
49
+ parsed = urlparse(url)
50
+ scheme = (parsed.scheme or "").lower()
51
+ if scheme not in policy.allowed_schemes:
52
+ raise WebFetchEgressViolation(
53
+ f"URL scheme {scheme!r} is not allowed for web_fetch"
54
+ )
55
+
56
+ host = parsed.hostname
57
+ if host is None or host == "":
58
+ raise WebFetchEgressViolation("web_fetch URL must include a host")
59
+
60
+ port = _port_for_url(parsed)
61
+
62
+ if policy.allow_private_network_targets:
63
+ return _stream_getaddrinfo_or_raise(host, port)
64
+
65
+ host_lower = host.lower()
66
+ if host_lower == "localhost" or host_lower.endswith(".localhost"):
67
+ raise WebFetchEgressViolation("localhost targets are not allowed for web_fetch")
68
+ if host_lower.endswith(".local"):
69
+ raise WebFetchEgressViolation(".local hostnames are not allowed for web_fetch")
70
+
71
+ try:
72
+ parsed_ip = ipaddress.ip_address(host)
73
+ except ValueError:
74
+ parsed_ip = None
75
+
76
+ if parsed_ip is not None:
77
+ if not parsed_ip.is_global:
78
+ raise WebFetchEgressViolation(
79
+ f"Non-public IP host {host!r} is not allowed for web_fetch"
80
+ )
81
+ return _stream_getaddrinfo_or_raise(host, port)
82
+
83
+ infos = _stream_getaddrinfo_or_raise(host, port)
84
+ for *_, sockaddr in infos:
85
+ addr = sockaddr[0]
86
+ try:
87
+ resolved = ipaddress.ip_address(addr)
88
+ except ValueError:
89
+ continue
90
+ if not resolved.is_global:
91
+ raise WebFetchEgressViolation(
92
+ f"Host {host!r} resolves to a non-public address ({resolved})"
93
+ )
94
+ return infos
95
+
96
+
97
+ def enforce_web_fetch_egress(url: str, policy: WebFetchEgressPolicy) -> None:
98
+ """Validate ``url`` (scheme, host, and resolved addresses) for web_fetch."""
99
+ get_validated_stream_addrinfos_for_egress(url, policy)
api/web_tools/outbound.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Outbound HTTP for web_search / web_fetch (client, body caps, logging)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import socket
7
+ from collections.abc import AsyncIterator
8
+ from urllib.parse import urljoin, urlparse
9
+
10
+ import aiohttp
11
+ import httpx
12
+ from aiohttp import ClientSession, ClientTimeout, TCPConnector
13
+ from aiohttp.abc import AbstractResolver, ResolveResult
14
+ from loguru import logger
15
+
16
+ from . import constants
17
+ from .constants import (
18
+ _MAX_FETCH_CHARS,
19
+ _MAX_SEARCH_RESULTS,
20
+ _REDIRECT_RESPONSE_BODY_CAP_BYTES,
21
+ _REQUEST_TIMEOUT_S,
22
+ _WEB_FETCH_REDIRECT_STATUSES,
23
+ _WEB_TOOL_HTTP_HEADERS,
24
+ )
25
+ from .egress import (
26
+ WebFetchEgressPolicy,
27
+ WebFetchEgressViolation,
28
+ get_validated_stream_addrinfos_for_egress,
29
+ )
30
+ from .parsers import HTMLTextParser, SearchResultParser
31
+
32
+
33
+ def _safe_public_host_for_logs(url: str) -> str:
34
+ host = urlparse(url).hostname or ""
35
+ return host[:253]
36
+
37
+
38
+ def _log_web_tool_failure(
39
+ tool_name: str,
40
+ error: BaseException,
41
+ *,
42
+ fetch_url: str | None = None,
43
+ ) -> None:
44
+ exc_type = type(error).__name__
45
+ if isinstance(error, WebFetchEgressViolation):
46
+ host = _safe_public_host_for_logs(fetch_url) if fetch_url else ""
47
+ logger.warning(
48
+ "web_tool_egress_rejected tool={} exc_type={} host={!r}",
49
+ tool_name,
50
+ exc_type,
51
+ host,
52
+ )
53
+ return
54
+ if tool_name == "web_fetch" and fetch_url:
55
+ logger.warning(
56
+ "web_tool_failure tool={} exc_type={} host={!r}",
57
+ tool_name,
58
+ exc_type,
59
+ _safe_public_host_for_logs(fetch_url),
60
+ )
61
+ else:
62
+ logger.warning("web_tool_failure tool={} exc_type={}", tool_name, exc_type)
63
+
64
+
65
+ def _web_tool_client_error_summary(
66
+ tool_name: str,
67
+ error: BaseException,
68
+ *,
69
+ verbose: bool,
70
+ ) -> str:
71
+ if verbose:
72
+ return f"{tool_name} failed: {type(error).__name__}"
73
+ return "Web tool request failed."
74
+
75
+
76
+ async def _iter_response_body_under_cap(
77
+ response: httpx.Response, max_bytes: int
78
+ ) -> AsyncIterator[bytes]:
79
+ if max_bytes <= 0:
80
+ return
81
+ received = 0
82
+ async for chunk in response.aiter_bytes(chunk_size=65_536):
83
+ if received >= max_bytes:
84
+ break
85
+ remaining = max_bytes - received
86
+ if len(chunk) <= remaining:
87
+ received += len(chunk)
88
+ yield chunk
89
+ if received >= max_bytes:
90
+ break
91
+ else:
92
+ yield chunk[:remaining]
93
+ break
94
+
95
+
96
+ async def _drain_response_body_capped(response: httpx.Response, max_bytes: int) -> None:
97
+ async for _ in _iter_response_body_under_cap(response, max_bytes):
98
+ pass
99
+
100
+
101
+ async def _read_response_body_capped(response: httpx.Response, max_bytes: int) -> bytes:
102
+ return b"".join(
103
+ [piece async for piece in _iter_response_body_under_cap(response, max_bytes)]
104
+ )
105
+
106
+
107
+ _NUMERIC_RESOLVE_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
108
+ _NAME_RESOLVE_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
109
+
110
+
111
+ def getaddrinfo_rows_to_resolve_results(
112
+ host: str, addrinfos: list[tuple]
113
+ ) -> list[ResolveResult]:
114
+ """Map :func:`socket.getaddrinfo` rows to aiohttp :class:`ResolveResult` (ThreadedResolver logic)."""
115
+ out: list[ResolveResult] = []
116
+ for family, _type, proto, _canon, sockaddr in addrinfos:
117
+ if family == socket.AF_INET6:
118
+ if len(sockaddr) < 3:
119
+ continue
120
+ if sockaddr[3]:
121
+ resolved_host, port = socket.getnameinfo(sockaddr, _NAME_RESOLVE_FLAGS)
122
+ else:
123
+ resolved_host, port = sockaddr[:2]
124
+ else:
125
+ assert family == socket.AF_INET, family
126
+ resolved_host, port = sockaddr[0], sockaddr[1]
127
+ resolved_host = str(resolved_host)
128
+ port = int(port)
129
+ out.append(
130
+ ResolveResult(
131
+ hostname=host,
132
+ host=resolved_host,
133
+ port=int(port),
134
+ family=family,
135
+ proto=proto,
136
+ flags=_NUMERIC_RESOLVE_FLAGS,
137
+ )
138
+ )
139
+ return out
140
+
141
+
142
+ class _PinnedEgressStaticResolver(AbstractResolver):
143
+ """Return only pre-validated :class:`ResolveResult` for the outbound request."""
144
+
145
+ def __init__(self, results: list[ResolveResult]) -> None:
146
+ self._results = results
147
+
148
+ async def resolve(
149
+ self, host: str, port: int = 0, family: int = socket.AF_INET
150
+ ) -> list[ResolveResult]:
151
+ return self._results
152
+
153
+ async def close(self) -> None: # pragma: no cover - aiohttp contract
154
+ return
155
+
156
+
157
+ async def _read_aiohttp_body_capped(
158
+ response: aiohttp.ClientResponse, max_bytes: int
159
+ ) -> bytes:
160
+ received = 0
161
+ parts: list[bytes] = []
162
+ async for chunk in response.content.iter_chunked(65_536):
163
+ if received >= max_bytes:
164
+ break
165
+ remaining = max_bytes - received
166
+ if len(chunk) <= remaining:
167
+ received += len(chunk)
168
+ parts.append(chunk)
169
+ else:
170
+ parts.append(chunk[:remaining])
171
+ break
172
+ return b"".join(parts)
173
+
174
+
175
+ async def _drain_aiohttp_body_capped(
176
+ response: aiohttp.ClientResponse, max_bytes: int
177
+ ) -> None:
178
+ if max_bytes <= 0:
179
+ return
180
+ received = 0
181
+ async for chunk in response.content.iter_chunked(65_536):
182
+ received += len(chunk)
183
+ if received >= max_bytes:
184
+ break
185
+
186
+
187
+ async def _run_web_search(query: str) -> list[dict[str, str]]:
188
+ async with (
189
+ httpx.AsyncClient(
190
+ timeout=_REQUEST_TIMEOUT_S,
191
+ follow_redirects=True,
192
+ headers=_WEB_TOOL_HTTP_HEADERS,
193
+ ) as client,
194
+ client.stream(
195
+ "GET",
196
+ "https://lite.duckduckgo.com/lite/",
197
+ params={"q": query},
198
+ ) as response,
199
+ ):
200
+ response.raise_for_status()
201
+ body_bytes = await _read_response_body_capped(
202
+ response, constants._MAX_WEB_FETCH_RESPONSE_BYTES
203
+ )
204
+ text = body_bytes.decode("utf-8", errors="replace")
205
+ parser = SearchResultParser()
206
+ parser.feed(text)
207
+ return parser.results[:_MAX_SEARCH_RESULTS]
208
+
209
+
210
+ async def _run_web_fetch(url: str, egress: WebFetchEgressPolicy) -> dict[str, str]:
211
+ """Fetch URL with manual redirects; each hop is DNS-pinned to validated addresses."""
212
+ current_url = url
213
+ redirect_hops = 0
214
+ timeout = ClientTimeout(total=_REQUEST_TIMEOUT_S)
215
+
216
+ while True:
217
+ addr_infos = await asyncio.to_thread(
218
+ get_validated_stream_addrinfos_for_egress, current_url, egress
219
+ )
220
+ host = urlparse(current_url).hostname or ""
221
+ results = getaddrinfo_rows_to_resolve_results(host, addr_infos)
222
+ resolver = _PinnedEgressStaticResolver(results)
223
+ connector = TCPConnector(
224
+ resolver=resolver,
225
+ force_close=True,
226
+ )
227
+ try:
228
+ async with (
229
+ ClientSession(
230
+ timeout=timeout,
231
+ headers=_WEB_TOOL_HTTP_HEADERS,
232
+ connector=connector,
233
+ ) as session,
234
+ session.get(current_url, allow_redirects=False) as response,
235
+ ):
236
+ if response.status in _WEB_FETCH_REDIRECT_STATUSES:
237
+ await _drain_aiohttp_body_capped(
238
+ response, _REDIRECT_RESPONSE_BODY_CAP_BYTES
239
+ )
240
+ if redirect_hops >= constants._MAX_WEB_FETCH_REDIRECTS:
241
+ raise WebFetchEgressViolation(
242
+ "web_fetch exceeded maximum redirects "
243
+ f"({constants._MAX_WEB_FETCH_REDIRECTS})"
244
+ )
245
+ location = response.headers.get("location")
246
+ if not location or not location.strip():
247
+ raise WebFetchEgressViolation(
248
+ "web_fetch redirect response missing Location header"
249
+ )
250
+ current_url = urljoin(str(response.url), location.strip())
251
+ redirect_hops += 1
252
+ continue
253
+ response.raise_for_status()
254
+ content_type = response.headers.get("content-type", "text/plain")
255
+ final_url = str(response.url)
256
+ encoding = response.get_encoding() or "utf-8"
257
+ body_bytes = await _read_aiohttp_body_capped(
258
+ response, constants._MAX_WEB_FETCH_RESPONSE_BYTES
259
+ )
260
+ finally:
261
+ await connector.close()
262
+
263
+ break
264
+
265
+ text = body_bytes.decode(encoding, errors="replace")
266
+ title = final_url
267
+ data = text
268
+ if "html" in content_type.lower():
269
+ parser = HTMLTextParser()
270
+ parser.feed(text)
271
+ title = parser.title or final_url
272
+ data = "\n".join(parser.text_parts)
273
+ return {
274
+ "url": final_url,
275
+ "title": title,
276
+ "media_type": "text/plain",
277
+ "data": data[:_MAX_FETCH_CHARS],
278
+ }
api/web_tools/parsers.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HTML parsing for web_search / web_fetch."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import html
6
+ import re
7
+ from html.parser import HTMLParser
8
+ from typing import Any
9
+ from urllib.parse import parse_qs, unquote, urlparse
10
+
11
+
12
+ class SearchResultParser(HTMLParser):
13
+ """DuckDuckGo lite HTML: extract result links and titles."""
14
+
15
+ def __init__(self) -> None:
16
+ super().__init__()
17
+ self.results: list[dict[str, str]] = []
18
+ self._href: str | None = None
19
+ self._title_parts: list[str] = []
20
+
21
+ def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
22
+ if tag != "a":
23
+ return
24
+ href = dict(attrs).get("href")
25
+ if not href or "uddg=" not in href:
26
+ return
27
+ parsed = urlparse(href)
28
+ query = parse_qs(parsed.query)
29
+ uddg = query.get("uddg", [""])[0]
30
+ if not uddg:
31
+ return
32
+ self._href = unquote(uddg)
33
+ self._title_parts = []
34
+
35
+ def handle_data(self, data: str) -> None:
36
+ if self._href is not None:
37
+ self._title_parts.append(data)
38
+
39
+ def handle_endtag(self, tag: str) -> None:
40
+ if tag != "a" or self._href is None:
41
+ return
42
+ title = " ".join("".join(self._title_parts).split())
43
+ if title and not any(result["url"] == self._href for result in self.results):
44
+ self.results.append({"title": html.unescape(title), "url": self._href})
45
+ self._href = None
46
+ self._title_parts = []
47
+
48
+
49
+ class HTMLTextParser(HTMLParser):
50
+ """Strip scripts/styles and collect visible text + title for fetch previews."""
51
+
52
+ def __init__(self) -> None:
53
+ super().__init__()
54
+ self.title = ""
55
+ self.text_parts: list[str] = []
56
+ self._in_title = False
57
+ self._skip_depth = 0
58
+
59
+ def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
60
+ if tag in {"script", "style", "noscript"}:
61
+ self._skip_depth += 1
62
+ elif tag == "title":
63
+ self._in_title = True
64
+
65
+ def handle_endtag(self, tag: str) -> None:
66
+ if tag in {"script", "style", "noscript"} and self._skip_depth:
67
+ self._skip_depth -= 1
68
+ elif tag == "title":
69
+ self._in_title = False
70
+
71
+ def handle_data(self, data: str) -> None:
72
+ text = " ".join(data.split())
73
+ if not text:
74
+ return
75
+ if self._in_title:
76
+ self.title = f"{self.title} {text}".strip()
77
+ elif not self._skip_depth:
78
+ self.text_parts.append(text)
79
+
80
+
81
+ def content_text(content: Any) -> str:
82
+ if isinstance(content, str):
83
+ return content
84
+ if isinstance(content, list):
85
+ parts = []
86
+ for item in content:
87
+ if isinstance(item, dict):
88
+ parts.append(str(item.get("text", "")))
89
+ else:
90
+ parts.append(str(getattr(item, "text", "")))
91
+ return "\n".join(part for part in parts if part)
92
+ return str(content)
93
+
94
+
95
+ def extract_query(text: str) -> str:
96
+ match = re.search(r"query:\s*(.+)", text, flags=re.IGNORECASE | re.DOTALL)
97
+ if match:
98
+ return match.group(1).strip().strip("\"'")
99
+ return text.strip()
100
+
101
+
102
+ def extract_url(text: str) -> str:
103
+ match = re.search(r"https?://\S+", text)
104
+ return match.group(0).rstrip(").,]") if match else text.strip()
api/web_tools/request.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Detect forced Anthropic web server tool requests."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from api.models.anthropic import MessagesRequest, Tool
6
+
7
+
8
+ def request_text(request: MessagesRequest) -> str:
9
+ """Join all user/assistant message content into one string for tool input parsing."""
10
+ from .parsers import content_text
11
+
12
+ return "\n".join(content_text(message.content) for message in request.messages)
13
+
14
+
15
+ def forced_tool_turn_text(request: MessagesRequest) -> str:
16
+ """Text for parsing forced server-tool inputs: latest user turn only (avoids stale history)."""
17
+ if not request.messages:
18
+ return ""
19
+
20
+ from .parsers import content_text
21
+
22
+ for message in reversed(request.messages):
23
+ if message.role == "user":
24
+ return content_text(message.content)
25
+ return ""
26
+
27
+
28
+ def forced_server_tool_name(request: MessagesRequest) -> str | None:
29
+ """Return web_search or web_fetch only when tool_choice forces that server tool."""
30
+ tc = request.tool_choice
31
+ if not isinstance(tc, dict):
32
+ return None
33
+ if tc.get("type") != "tool":
34
+ return None
35
+ name = tc.get("name")
36
+ if name in {"web_search", "web_fetch"}:
37
+ return str(name)
38
+ return None
39
+
40
+
41
+ def has_tool_named(request: MessagesRequest, name: str) -> bool:
42
+ return any(tool.name == name for tool in request.tools or [])
43
+
44
+
45
+ def is_web_server_tool_request(request: MessagesRequest) -> bool:
46
+ """True when the client forces a web server tool via tool_choice (not merely listed)."""
47
+ forced = forced_server_tool_name(request)
48
+ if forced is None:
49
+ return False
50
+ return has_tool_named(request, forced)
51
+
52
+
53
+ def is_anthropic_server_tool_definition(tool: Tool) -> bool:
54
+ """Whether ``tool`` refers to an Anthropic server tool (web_search / web_fetch family)."""
55
+ name = (tool.name or "").strip()
56
+ if name in ("web_search", "web_fetch"):
57
+ return True
58
+ typ = tool.type
59
+ if isinstance(typ, str):
60
+ return typ.startswith("web_search") or typ.startswith("web_fetch")
61
+ return False
62
+
63
+
64
+ def has_listed_anthropic_server_tools(request: MessagesRequest) -> bool:
65
+ """True when tools include web_search / web_fetch-style entries (listed, forced or not)."""
66
+ return any(is_anthropic_server_tool_definition(t) for t in (request.tools or []))
67
+
68
+
69
+ def openai_chat_upstream_server_tool_error(
70
+ request: MessagesRequest, *, web_tools_enabled: bool
71
+ ) -> str | None:
72
+ """Return a user-facing error when OpenAI Chat upstream cannot satisfy server-tool semantics."""
73
+ forced = forced_server_tool_name(request)
74
+ if forced and not web_tools_enabled:
75
+ return (
76
+ f"tool_choice forces Anthropic server tool {forced!r}, but local web server tools are "
77
+ "disabled (ENABLE_WEB_SERVER_TOOLS=false). Enable them to use this tool."
78
+ )
79
+ if not forced and has_listed_anthropic_server_tools(request):
80
+ return (
81
+ "OpenAI Chat upstreams (NVIDIA NIM) cannot use listed Anthropic server tools "
82
+ "(web_search / web_fetch) without the local web server tool handler. "
83
+ "Set ENABLE_WEB_SERVER_TOOLS=true and force the tool with "
84
+ "tool_choice, or remove these tools from the request."
85
+ )
86
+ return None
api/web_tools/streaming.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SSE streaming for local web_search / web_fetch server tool results."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import uuid
6
+ from collections.abc import AsyncIterator
7
+ from datetime import UTC, datetime
8
+ from typing import Any
9
+
10
+ from api.models.anthropic import MessagesRequest
11
+ from core.anthropic.server_tool_sse import (
12
+ SERVER_TOOL_USE,
13
+ WEB_FETCH_TOOL_ERROR,
14
+ WEB_FETCH_TOOL_RESULT,
15
+ WEB_SEARCH_TOOL_RESULT,
16
+ WEB_SEARCH_TOOL_RESULT_ERROR,
17
+ )
18
+ from core.anthropic.sse import format_sse_event
19
+
20
+ from .constants import _MAX_FETCH_CHARS
21
+ from .egress import WebFetchEgressPolicy
22
+ from .parsers import extract_query, extract_url
23
+ from .request import (
24
+ forced_server_tool_name,
25
+ forced_tool_turn_text,
26
+ has_tool_named,
27
+ )
28
+
29
+
30
+ def _search_summary(query: str, results: list[dict[str, str]]) -> str:
31
+ if not results:
32
+ return f"No web search results found for: {query}"
33
+ lines = [f"Search results for: {query}"]
34
+ for index, result in enumerate(results, start=1):
35
+ lines.append(f"{index}. {result['title']}\n{result['url']}")
36
+ return "\n\n".join(lines)
37
+
38
+
39
+ async def stream_web_server_tool_response(
40
+ request: MessagesRequest,
41
+ input_tokens: int,
42
+ *,
43
+ web_fetch_egress: WebFetchEgressPolicy,
44
+ verbose_client_errors: bool = False,
45
+ ) -> AsyncIterator[str]:
46
+ """Stream a minimal Anthropic-shaped turn for forced `web_search` / `web_fetch` (local fallback).
47
+
48
+ When `ENABLE_WEB_SERVER_TOOLS` is on, this is a proxy-side execution path — not a full
49
+ hosted Anthropic citation or encrypted-content pipeline.
50
+ """
51
+ from . import outbound
52
+ tool_name = forced_server_tool_name(request)
53
+ if tool_name is None or not has_tool_named(request, tool_name):
54
+ return
55
+
56
+ text = forced_tool_turn_text(request)
57
+ message_id = f"msg_{uuid.uuid4()}"
58
+ tool_id = f"srvtoolu_{uuid.uuid4().hex}"
59
+ usage_key = (
60
+ "web_search_requests" if tool_name == "web_search" else "web_fetch_requests"
61
+ )
62
+ tool_input = (
63
+ {"query": extract_query(text)}
64
+ if tool_name == "web_search"
65
+ else {"url": extract_url(text)}
66
+ )
67
+ _result_block_for_tool = {
68
+ "web_search": WEB_SEARCH_TOOL_RESULT,
69
+ "web_fetch": WEB_FETCH_TOOL_RESULT,
70
+ }
71
+ _error_payload_type_for_tool = {
72
+ "web_search": WEB_SEARCH_TOOL_RESULT_ERROR,
73
+ "web_fetch": WEB_FETCH_TOOL_ERROR,
74
+ }
75
+
76
+ yield format_sse_event(
77
+ "message_start",
78
+ {
79
+ "type": "message_start",
80
+ "message": {
81
+ "id": message_id,
82
+ "type": "message",
83
+ "role": "assistant",
84
+ "content": [],
85
+ "model": request.model,
86
+ "stop_reason": None,
87
+ "stop_sequence": None,
88
+ "usage": {"input_tokens": input_tokens, "output_tokens": 1},
89
+ },
90
+ },
91
+ )
92
+ yield format_sse_event(
93
+ "content_block_start",
94
+ {
95
+ "type": "content_block_start",
96
+ "index": 0,
97
+ "content_block": {
98
+ "type": SERVER_TOOL_USE,
99
+ "id": tool_id,
100
+ "name": tool_name,
101
+ "input": tool_input,
102
+ },
103
+ },
104
+ )
105
+ yield format_sse_event(
106
+ "content_block_stop", {"type": "content_block_stop", "index": 0}
107
+ )
108
+
109
+ try:
110
+ if tool_name == "web_search":
111
+ query = str(tool_input["query"])
112
+ results = await outbound._run_web_search(query)
113
+ result_content: Any = [
114
+ {
115
+ "type": "web_search_result",
116
+ "title": result["title"],
117
+ "url": result["url"],
118
+ }
119
+ for result in results
120
+ ]
121
+ summary = _search_summary(query, results)
122
+ result_block_type = WEB_SEARCH_TOOL_RESULT
123
+ else:
124
+ fetched = await outbound._run_web_fetch(
125
+ str(tool_input["url"]), web_fetch_egress
126
+ )
127
+ result_content = {
128
+ "type": "web_fetch_result",
129
+ "url": fetched["url"],
130
+ "content": {
131
+ "type": "document",
132
+ "source": {
133
+ "type": "text",
134
+ "media_type": fetched["media_type"],
135
+ "data": fetched["data"],
136
+ },
137
+ "title": fetched["title"],
138
+ "citations": {"enabled": True},
139
+ },
140
+ "retrieved_at": datetime.now(UTC).isoformat(),
141
+ }
142
+ summary = fetched["data"][:_MAX_FETCH_CHARS]
143
+ result_block_type = WEB_FETCH_TOOL_RESULT
144
+ except Exception as error:
145
+ fetch_url = str(tool_input["url"]) if tool_name == "web_fetch" else None
146
+ outbound._log_web_tool_failure(tool_name, error, fetch_url=fetch_url)
147
+ result_block_type = _result_block_for_tool[tool_name]
148
+ result_content = {
149
+ "type": _error_payload_type_for_tool[tool_name],
150
+ "error_code": "unavailable",
151
+ }
152
+ summary = outbound._web_tool_client_error_summary(
153
+ tool_name, error, verbose=verbose_client_errors
154
+ )
155
+
156
+ output_tokens = max(1, len(summary) // 4)
157
+
158
+ yield format_sse_event(
159
+ "content_block_start",
160
+ {
161
+ "type": "content_block_start",
162
+ "index": 1,
163
+ "content_block": {
164
+ "type": result_block_type,
165
+ "tool_use_id": tool_id,
166
+ "content": result_content,
167
+ },
168
+ },
169
+ )
170
+ yield format_sse_event(
171
+ "content_block_stop", {"type": "content_block_stop", "index": 1}
172
+ )
173
+ # Model-facing summary: stream as normal text deltas (CLI/transcript code reads `text_delta`,
174
+ # not eager `text` on `content_block_start`).
175
+ yield format_sse_event(
176
+ "content_block_start",
177
+ {
178
+ "type": "content_block_start",
179
+ "index": 2,
180
+ "content_block": {"type": "text", "text": ""},
181
+ },
182
+ )
183
+ yield format_sse_event(
184
+ "content_block_delta",
185
+ {
186
+ "type": "content_block_delta",
187
+ "index": 2,
188
+ "delta": {"type": "text_delta", "text": summary},
189
+ },
190
+ )
191
+ yield format_sse_event(
192
+ "content_block_stop", {"type": "content_block_stop", "index": 2}
193
+ )
194
+ yield format_sse_event(
195
+ "message_delta",
196
+ {
197
+ "type": "message_delta",
198
+ "delta": {"stop_reason": "end_turn", "stop_sequence": None},
199
+ "usage": {
200
+ "input_tokens": input_tokens,
201
+ "output_tokens": output_tokens,
202
+ "server_tool_use": {usage_key: 1},
203
+ },
204
+ },
205
+ )
206
+ yield format_sse_event("message_stop", {"type": "message_stop"})
cli/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """CLI integration for Claude Code."""
2
+
3
+ from .manager import CLISessionManager
4
+ from .session import CLISession
5
+
6
+ __all__ = ["CLISession", "CLISessionManager"]
cli/entrypoints.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CLI entry points for the installed package."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+
8
+ def _load_env_template() -> str:
9
+ """Load the canonical root env template from package resources or source."""
10
+ import importlib.resources
11
+
12
+ packaged = importlib.resources.files("cli").joinpath("env.example")
13
+ if packaged.is_file():
14
+ return packaged.read_text("utf-8")
15
+
16
+ source_template = Path(__file__).resolve().parents[1] / ".env.example"
17
+ if source_template.is_file():
18
+ return source_template.read_text(encoding="utf-8")
19
+
20
+ raise FileNotFoundError("Could not find bundled or source .env.example template.")
21
+
22
+
23
+ def serve() -> None:
24
+ """Start the FastAPI server (registered as `free-claude-code` script)."""
25
+ import uvicorn
26
+
27
+ from cli.process_registry import kill_all_best_effort
28
+ from config.settings import get_settings
29
+
30
+ settings = get_settings()
31
+ try:
32
+ uvicorn.run(
33
+ "api.app:create_asgi_app",
34
+ factory=True,
35
+ host=settings.host,
36
+ port=settings.port,
37
+ log_level="debug",
38
+ timeout_graceful_shutdown=5,
39
+ )
40
+ finally:
41
+ kill_all_best_effort()
42
+
43
+
44
+ def init() -> None:
45
+ """Scaffold config at ~/.config/free-claude-code/.env (registered as `fcc-init`)."""
46
+ config_dir = Path.home() / ".config" / "free-claude-code"
47
+ env_file = config_dir / ".env"
48
+
49
+ if env_file.exists():
50
+ print(f"Config already exists at {env_file}")
51
+ print("Delete it first if you want to reset to defaults.")
52
+ return
53
+
54
+ config_dir.mkdir(parents=True, exist_ok=True)
55
+ template = _load_env_template()
56
+ env_file.write_text(template, encoding="utf-8")
57
+ print(f"Config created at {env_file}")
58
+ print(
59
+ "Edit it to set your API keys and model preferences, then run: free-claude-code"
60
+ )
cli/manager.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CLI Session Manager for Multi-Instance Claude CLI Support
3
+
4
+ Manages a pool of CLISession instances, each handling one conversation.
5
+ This enables true parallel processing where multiple conversations run
6
+ simultaneously in separate CLI processes.
7
+ """
8
+
9
+ import asyncio
10
+ import uuid
11
+
12
+ from loguru import logger
13
+
14
+ from .session import CLISession
15
+
16
+
17
+ class CLISessionManager:
18
+ """
19
+ Manages multiple CLISession instances for parallel conversation processing.
20
+
21
+ Each new conversation gets its own CLISession with its own subprocess.
22
+ Replies to existing conversations reuse the same CLISession instance.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ workspace_path: str,
28
+ api_url: str,
29
+ allowed_dirs: list[str] | None = None,
30
+ plans_directory: str | None = None,
31
+ claude_bin: str = "claude",
32
+ *,
33
+ log_raw_cli_diagnostics: bool = False,
34
+ log_messaging_error_details: bool = False,
35
+ ):
36
+ """
37
+ Initialize the session manager.
38
+
39
+ Args:
40
+ workspace_path: Working directory for CLI processes
41
+ api_url: API URL for the proxy
42
+ allowed_dirs: Directories the CLI is allowed to access
43
+ plans_directory: Directory for Claude Code CLI plan files (passed via --settings)
44
+ """
45
+ self.workspace = workspace_path
46
+ self.api_url = api_url
47
+ self.allowed_dirs = allowed_dirs or []
48
+ self.plans_directory = plans_directory
49
+ self.claude_bin = claude_bin
50
+ self._log_raw_cli_diagnostics = log_raw_cli_diagnostics
51
+ self._log_messaging_error_details = log_messaging_error_details
52
+
53
+ self._sessions: dict[str, CLISession] = {}
54
+ self._pending_sessions: dict[str, CLISession] = {}
55
+ self._temp_to_real: dict[str, str] = {}
56
+ self._real_to_temp: dict[str, str] = {}
57
+ self._lock = asyncio.Lock()
58
+
59
+ logger.info("CLISessionManager initialized")
60
+
61
+ async def get_or_create_session(
62
+ self, session_id: str | None = None
63
+ ) -> tuple[CLISession, str, bool]:
64
+ """
65
+ Get an existing session or create a new one.
66
+
67
+ Returns:
68
+ Tuple of (CLISession instance, session_id, is_new_session)
69
+ """
70
+ async with self._lock:
71
+ if session_id:
72
+ lookup_id = self._temp_to_real.get(session_id, session_id)
73
+
74
+ if lookup_id in self._sessions:
75
+ return self._sessions[lookup_id], lookup_id, False
76
+ if lookup_id in self._pending_sessions:
77
+ return self._pending_sessions[lookup_id], lookup_id, False
78
+
79
+ temp_id = session_id if session_id else f"pending_{uuid.uuid4().hex[:8]}"
80
+
81
+ new_session = CLISession(
82
+ workspace_path=self.workspace,
83
+ api_url=self.api_url,
84
+ allowed_dirs=self.allowed_dirs,
85
+ plans_directory=self.plans_directory,
86
+ claude_bin=self.claude_bin,
87
+ log_raw_cli_diagnostics=self._log_raw_cli_diagnostics,
88
+ )
89
+ self._pending_sessions[temp_id] = new_session
90
+ logger.info(f"Created new session: {temp_id}")
91
+
92
+ return new_session, temp_id, True
93
+
94
+ async def register_real_session_id(
95
+ self, temp_id: str, real_session_id: str
96
+ ) -> bool:
97
+ """Register the real session ID from CLI output."""
98
+ async with self._lock:
99
+ if temp_id not in self._pending_sessions:
100
+ logger.warning(f"Temp session {temp_id} not found")
101
+ return False
102
+
103
+ session = self._pending_sessions.pop(temp_id)
104
+ self._sessions[real_session_id] = session
105
+ self._temp_to_real[temp_id] = real_session_id
106
+ self._real_to_temp[real_session_id] = temp_id
107
+
108
+ logger.info(f"Registered session: {temp_id} -> {real_session_id}")
109
+ return True
110
+
111
+ async def remove_session(self, session_id: str) -> bool:
112
+ """Remove a session from the manager."""
113
+ async with self._lock:
114
+ if session_id in self._pending_sessions:
115
+ session = self._pending_sessions.pop(session_id)
116
+ await session.stop()
117
+ return True
118
+
119
+ if session_id in self._sessions:
120
+ session = self._sessions.pop(session_id)
121
+ await session.stop()
122
+ temp_id = self._real_to_temp.pop(session_id, None)
123
+ if temp_id is not None:
124
+ self._temp_to_real.pop(temp_id, None)
125
+ return True
126
+
127
+ return False
128
+
129
+ async def stop_all(self):
130
+ """Stop all sessions."""
131
+ async with self._lock:
132
+ all_sessions = list(self._sessions.values()) + list(
133
+ self._pending_sessions.values()
134
+ )
135
+ for session in all_sessions:
136
+ try:
137
+ await session.stop()
138
+ except Exception as e:
139
+ if self._log_messaging_error_details:
140
+ logger.error(
141
+ "Error stopping session: {}: {}",
142
+ type(e).__name__,
143
+ e,
144
+ )
145
+ else:
146
+ logger.error(
147
+ "Error stopping session: exc_type={}",
148
+ type(e).__name__,
149
+ )
150
+
151
+ self._sessions.clear()
152
+ self._pending_sessions.clear()
153
+ self._temp_to_real.clear()
154
+ self._real_to_temp.clear()
155
+ logger.info("All sessions stopped")
156
+
157
+ def get_stats(self) -> dict:
158
+ """Get session statistics."""
159
+ return {
160
+ "active_sessions": len(self._sessions),
161
+ "pending_sessions": len(self._pending_sessions),
162
+ "busy_count": sum(1 for s in self._sessions.values() if s.is_busy),
163
+ }
cli/process_registry.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Track and clean up spawned CLI subprocesses.
2
+
3
+ This is a safety net for cases where the server is interrupted (Ctrl+C) and the
4
+ FastAPI lifespan cleanup doesn't run to completion. We only track processes we
5
+ spawn so we don't accidentally kill unrelated system processes.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import atexit
11
+ import os
12
+ import subprocess
13
+ import threading
14
+
15
+ from loguru import logger
16
+
17
+ _lock = threading.Lock()
18
+ _pids: set[int] = set()
19
+ _atexit_registered = False
20
+
21
+
22
+ def ensure_atexit_registered() -> None:
23
+ global _atexit_registered
24
+ with _lock:
25
+ if _atexit_registered:
26
+ return
27
+ atexit.register(kill_all_best_effort)
28
+ _atexit_registered = True
29
+
30
+
31
+ def register_pid(pid: int) -> None:
32
+ if not pid:
33
+ return
34
+ ensure_atexit_registered()
35
+ with _lock:
36
+ _pids.add(int(pid))
37
+
38
+
39
+ def unregister_pid(pid: int) -> None:
40
+ if not pid:
41
+ return
42
+ with _lock:
43
+ _pids.discard(int(pid))
44
+
45
+
46
+ def kill_all_best_effort() -> None:
47
+ """Kill any still-running registered pids (best-effort)."""
48
+ with _lock:
49
+ pids = list(_pids)
50
+ _pids.clear()
51
+
52
+ if not pids:
53
+ return
54
+
55
+ if os.name == "nt":
56
+ for pid in pids:
57
+ try:
58
+ # /T kills child processes, /F forces termination.
59
+ subprocess.run(
60
+ ["taskkill", "/PID", str(pid), "/T", "/F"],
61
+ stdout=subprocess.DEVNULL,
62
+ stderr=subprocess.DEVNULL,
63
+ check=False,
64
+ )
65
+ except Exception as e:
66
+ logger.debug("process_registry: taskkill failed pid=%s: %s", pid, e)
67
+ return
68
+
69
+ # Best-effort fallback for non-Windows.
70
+ for pid in pids:
71
+ try:
72
+ os.kill(pid, 9)
73
+ except Exception as e:
74
+ logger.debug("process_registry: kill failed pid=%s: %s", pid, e)