NiWaRe commited on
Commit
a2dc155
·
1 Parent(s): 04c12c4

add api key based http auth

Browse files
app.py CHANGED
@@ -23,7 +23,7 @@ os.environ["WANDB_SILENT"] = "True"
23
  os.environ["WEAVE_SILENT"] = "True"
24
 
25
  from fastapi import FastAPI
26
- from fastapi.responses import HTMLResponse
27
  from fastapi.middleware.cors import CORSMiddleware
28
  from mcp.server.fastmcp import FastMCP
29
 
@@ -37,6 +37,13 @@ from wandb_mcp_server.server import (
37
  ServerMCPArgs
38
  )
39
 
 
 
 
 
 
 
 
40
  # Configure logging
41
  logging.basicConfig(
42
  level=logging.INFO,
@@ -61,15 +68,17 @@ args = ServerMCPArgs(
61
  )
62
 
63
  wandb_configured = False
64
- try:
65
- api_key = validate_and_get_api_key(args)
66
- setup_wandb_login(api_key)
67
- initialize_weave_tracing()
68
- wandb_configured = True
69
- logger.info("W&B API configured successfully")
70
- except ValueError as e:
71
- logger.warning(f"W&B API key not configured: {e}")
72
- logger.warning("Server will start but W&B operations will fail")
 
 
73
 
74
  # Create the MCP server
75
  logger.info("Creating W&B MCP server...")
@@ -103,12 +112,24 @@ app.add_middleware(
103
  allow_headers=["*"],
104
  )
105
 
 
 
 
 
 
 
106
  # Add custom routes
107
  @app.get("/", response_class=HTMLResponse)
108
  async def index():
109
  """Serve the landing page."""
110
  return INDEX_HTML_CONTENT
111
 
 
 
 
 
 
 
112
  @app.get("/health")
113
  async def health():
114
  """Health check endpoint."""
@@ -119,11 +140,14 @@ async def health():
119
  except:
120
  tool_count = 0
121
 
 
 
122
  return {
123
  "status": "healthy",
124
  "service": "wandb-mcp-server",
125
  "wandb_configured": wandb_configured,
126
- "tools_registered": tool_count
 
127
  }
128
 
129
  # Mount the MCP streamable HTTP app
 
23
  os.environ["WEAVE_SILENT"] = "True"
24
 
25
  from fastapi import FastAPI
26
+ from fastapi.responses import HTMLResponse, JSONResponse
27
  from fastapi.middleware.cors import CORSMiddleware
28
  from mcp.server.fastmcp import FastMCP
29
 
 
37
  ServerMCPArgs
38
  )
39
 
40
+ # Import authentication
41
+ from wandb_mcp_server.auth import (
42
+ mcp_auth_middleware,
43
+ create_resource_metadata_response,
44
+ MCPAuthConfig
45
+ )
46
+
47
  # Configure logging
48
  logging.basicConfig(
49
  level=logging.INFO,
 
68
  )
69
 
70
  wandb_configured = False
71
+ api_key = validate_and_get_api_key(args)
72
+ if api_key:
73
+ try:
74
+ setup_wandb_login(api_key)
75
+ initialize_weave_tracing()
76
+ wandb_configured = True
77
+ logger.info("Server W&B API key configured successfully")
78
+ except Exception as e:
79
+ logger.warning(f"Failed to configure server W&B API key: {e}")
80
+ else:
81
+ logger.info("No server W&B API key configured - clients will provide their own")
82
 
83
  # Create the MCP server
84
  logger.info("Creating W&B MCP server...")
 
112
  allow_headers=["*"],
113
  )
114
 
115
+ # Add authentication middleware for MCP endpoints
116
+ @app.middleware("http")
117
+ async def auth_middleware(request, call_next):
118
+ """Add OAuth 2.1 Bearer token authentication for MCP endpoints."""
119
+ return await mcp_auth_middleware(request, call_next)
120
+
121
  # Add custom routes
122
  @app.get("/", response_class=HTMLResponse)
123
  async def index():
124
  """Serve the landing page."""
125
  return INDEX_HTML_CONTENT
126
 
127
+ @app.get("/.well-known/oauth-protected-resource")
128
+ async def resource_metadata():
129
+ """OAuth 2.0 Protected Resource Metadata endpoint (RFC 9728)."""
130
+ config = MCPAuthConfig()
131
+ return JSONResponse(create_resource_metadata_response(config))
132
+
133
  @app.get("/health")
134
  async def health():
135
  """Health check endpoint."""
 
140
  except:
141
  tool_count = 0
142
 
143
+ auth_status = "disabled" if os.environ.get("MCP_AUTH_DISABLED", "false").lower() == "true" else "enabled"
144
+
145
  return {
146
  "status": "healthy",
147
  "service": "wandb-mcp-server",
148
  "wandb_configured": wandb_configured,
149
+ "tools_registered": tool_count,
150
+ "authentication": auth_status
151
  }
152
 
153
  # Mount the MCP streamable HTTP app
src/wandb_mcp_server/auth.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Authentication middleware for W&B MCP Server.
3
+
4
+ Implements Bearer token validation for HTTP transport as per
5
+ MCP specification: https://modelcontextprotocol.io/specification/draft/basic/authorization
6
+
7
+ Clients send their W&B API keys as Bearer tokens, which the server
8
+ then uses for all W&B operations on behalf of that client.
9
+ """
10
+
11
+ import os
12
+ import logging
13
+ import re
14
+ from typing import Optional, Dict, Any
15
+ from fastapi import HTTPException, Request, status
16
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
17
+ from fastapi.responses import JSONResponse
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # Bearer token security scheme
22
+ bearer_scheme = HTTPBearer(auto_error=False)
23
+
24
+
25
+ class MCPAuthConfig:
26
+ """
27
+ Configuration for MCP authentication.
28
+
29
+ For HTTP transport: Accepts any W&B API key as a Bearer token.
30
+ The server uses the client's token for all W&B operations.
31
+ """
32
+
33
+ def __init__(self):
34
+ self.resource_metadata_url = os.environ.get(
35
+ "MCP_RESOURCE_METADATA_URL",
36
+ "/.well-known/oauth-protected-resource"
37
+ )
38
+ # Point to W&B's Auth0 instance for reference
39
+ self.authorization_server = os.environ.get(
40
+ "MCP_AUTH_SERVER",
41
+ "https://wandb.auth0.com"
42
+ )
43
+
44
+
45
+ def is_valid_wandb_api_key(token: str) -> bool:
46
+ """
47
+ Check if a token looks like a valid W&B API key format.
48
+ W&B API keys are typically 40 characters of alphanumeric + some special chars.
49
+ """
50
+ if not token or len(token) < 20 or len(token) > 100:
51
+ return False
52
+ # Basic validation - W&B keys contain alphanumeric and some special characters
53
+ # This is a permissive check since W&B key format may vary
54
+ if re.match(r'^[a-zA-Z0-9_\-\.]+$', token):
55
+ return True
56
+ return False
57
+
58
+
59
+ async def validate_bearer_token(
60
+ credentials: Optional[HTTPAuthorizationCredentials],
61
+ config: MCPAuthConfig
62
+ ) -> str:
63
+ """
64
+ Validate Bearer token (W&B API key) for MCP access.
65
+
66
+ Accepts any valid-looking W&B API key. The actual validation
67
+ happens when the key is used to call W&B APIs.
68
+
69
+ Returns:
70
+ The W&B API key to use for operations
71
+
72
+ Raises:
73
+ HTTPException: 401 Unauthorized with WWW-Authenticate header
74
+ """
75
+ if not credentials or not credentials.credentials:
76
+ raise HTTPException(
77
+ status_code=status.HTTP_401_UNAUTHORIZED,
78
+ detail="Authorization required - please provide your W&B API key as a Bearer token",
79
+ headers={
80
+ "WWW-Authenticate": f'Bearer realm="W&B MCP", '
81
+ f'resource_metadata="{config.resource_metadata_url}"'
82
+ }
83
+ )
84
+
85
+ token = credentials.credentials
86
+
87
+ # Basic format validation
88
+ if not is_valid_wandb_api_key(token):
89
+ raise HTTPException(
90
+ status_code=status.HTTP_401_UNAUTHORIZED,
91
+ detail="Invalid W&B API key format. Get your key at: https://wandb.ai/authorize",
92
+ headers={
93
+ "WWW-Authenticate": f'Bearer realm="W&B MCP", '
94
+ f'error="invalid_token", '
95
+ f'resource_metadata="{config.resource_metadata_url}"'
96
+ }
97
+ )
98
+
99
+ logger.debug("Bearer token validated successfully")
100
+ return token
101
+
102
+
103
+ async def mcp_auth_middleware(request: Request, call_next):
104
+ """
105
+ FastAPI middleware for MCP authentication on HTTP transport.
106
+
107
+ Only applies to MCP endpoints (/mcp/*).
108
+ Extracts the client's W&B API key from the Bearer token and stores it
109
+ for use in W&B operations.
110
+ """
111
+ # Only apply auth to MCP endpoints
112
+ if not request.url.path.startswith("/mcp"):
113
+ return await call_next(request)
114
+
115
+ # Skip auth if explicitly disabled (development only)
116
+ if os.environ.get("MCP_AUTH_DISABLED", "false").lower() == "true":
117
+ logger.warning("MCP authentication is disabled - endpoints are publicly accessible")
118
+ return await call_next(request)
119
+
120
+ config = MCPAuthConfig()
121
+
122
+ try:
123
+ # Extract bearer token from Authorization header
124
+ authorization = request.headers.get("Authorization", "")
125
+ credentials = None
126
+ if authorization.startswith("Bearer "):
127
+ credentials = HTTPAuthorizationCredentials(
128
+ scheme="Bearer",
129
+ credentials=authorization[7:] # Remove "Bearer " prefix
130
+ )
131
+
132
+ # Validate and get the W&B API key
133
+ wandb_api_key = await validate_bearer_token(credentials, config)
134
+
135
+ # Store the API key in request state for W&B operations
136
+ # The MCP tools should access this from the request context
137
+ request.state.wandb_api_key = wandb_api_key
138
+
139
+ # For now, we'll set it in environment (in production, use contextvars)
140
+ # Save the original value to restore later
141
+ original_api_key = os.environ.get("WANDB_API_KEY")
142
+ os.environ["WANDB_API_KEY"] = wandb_api_key
143
+
144
+ try:
145
+ # Continue processing
146
+ response = await call_next(request)
147
+ finally:
148
+ # Restore original environment
149
+ if original_api_key:
150
+ os.environ["WANDB_API_KEY"] = original_api_key
151
+ elif "WANDB_API_KEY" in os.environ:
152
+ del os.environ["WANDB_API_KEY"]
153
+
154
+ return response
155
+
156
+ except HTTPException as e:
157
+ # Return proper error response
158
+ return JSONResponse(
159
+ status_code=e.status_code,
160
+ content={"error": e.detail},
161
+ headers=e.headers
162
+ )
163
+ except Exception as e:
164
+ logger.error(f"Authentication error: {e}")
165
+ return JSONResponse(
166
+ status_code=status.HTTP_401_UNAUTHORIZED,
167
+ content={"error": "Authentication failed"},
168
+ headers={
169
+ "WWW-Authenticate": f'Bearer realm="W&B MCP", '
170
+ f'resource_metadata="{config.resource_metadata_url}"'
171
+ }
172
+ )
173
+
174
+
175
+ def create_resource_metadata_response(config: MCPAuthConfig) -> Dict[str, Any]:
176
+ """
177
+ Create OAuth 2.0 Protected Resource Metadata response (RFC 9728).
178
+
179
+ This tells MCP clients that we use W&B API keys as Bearer tokens.
180
+ Points to W&B's Auth0 instance where users can get their API keys.
181
+ """
182
+ return {
183
+ "resource": os.environ.get("MCP_SERVER_URL", "https://wandb-mcp-server.hf.space"),
184
+ "authorization_servers": [config.authorization_server],
185
+ "bearer_methods_supported": ["header"],
186
+ "resource_documentation": "https://github.com/wandb/wandb-mcp-server",
187
+ "authentication_note": "Use your W&B API key as a Bearer token. Get your key at https://wandb.ai/authorize",
188
+ }
src/wandb_mcp_server/mcp_tools/query_weave.py CHANGED
@@ -6,7 +6,14 @@ from wandb_mcp_server.weave_api.models import QueryResult
6
 
7
  logger = get_rich_logger(__name__)
8
 
9
- _trace_service = TraceService()
 
 
 
 
 
 
 
10
 
11
  QUERY_WEAVE_TRACES_TOOL_DESCRIPTION = """
12
  Query Weave traces, trace metadata, and trace costs with filtering and sorting options.
@@ -301,7 +308,7 @@ def query_traces(
301
  but delegates to our new implementation.
302
  """
303
  # If api_key was provided, create a new service with that key
304
- service = _trace_service
305
  if api_key:
306
  service = TraceService(
307
  api_key=api_key,
@@ -410,7 +417,7 @@ async def query_paginated_weave_traces(
410
  QueryResult: A Pydantic model containing the query results
411
  """
412
  # If api_key was provided, create a new service with that key
413
- service = _trace_service
414
  if api_key:
415
  service = TraceService(
416
  api_key=api_key,
 
6
 
7
  logger = get_rich_logger(__name__)
8
 
9
+ # Lazy load the trace service to avoid requiring API key at import time
10
+ _trace_service = None
11
+
12
+ def get_trace_service():
13
+ global _trace_service
14
+ if _trace_service is None:
15
+ _trace_service = TraceService()
16
+ return _trace_service
17
 
18
  QUERY_WEAVE_TRACES_TOOL_DESCRIPTION = """
19
  Query Weave traces, trace metadata, and trace costs with filtering and sorting options.
 
308
  but delegates to our new implementation.
309
  """
310
  # If api_key was provided, create a new service with that key
311
+ service = get_trace_service()
312
  if api_key:
313
  service = TraceService(
314
  api_key=api_key,
 
417
  QueryResult: A Pydantic model containing the query results
418
  """
419
  # If api_key was provided, create a new service with that key
420
+ service = get_trace_service()
421
  if api_key:
422
  service = TraceService(
423
  api_key=api_key,
src/wandb_mcp_server/server.py CHANGED
@@ -119,10 +119,13 @@ def setup_wandb_login(api_key: str) -> None:
119
  sys.stderr = original_stderr
120
 
121
 
122
- def validate_and_get_api_key(args: ServerMCPArgs) -> str:
123
  """
124
  Validate and retrieve the W&B API key from various sources.
125
 
 
 
 
126
  Priority order:
127
  1. Command-line argument (--wandb-api-key)
128
  2. Environment variable (WANDB_API_KEY)
@@ -133,16 +136,25 @@ def validate_and_get_api_key(args: ServerMCPArgs) -> str:
133
  args: Parsed command-line arguments
134
 
135
  Returns:
136
- The W&B API key
137
 
138
  Raises:
139
- ValueError: If no API key is found
140
  """
141
  api_key = args.wandb_api_key or get_server_args().wandb_api_key
142
 
 
 
 
 
 
 
 
 
 
143
  if not api_key:
144
  raise ValueError(
145
- "WANDB_API_KEY must be set. Options:\n"
146
  "1. Command-line: --wandb-api-key YOUR_KEY\n"
147
  "2. Environment: export WANDB_API_KEY=YOUR_KEY\n"
148
  "3. .env file: WANDB_API_KEY=YOUR_KEY\n"
@@ -387,14 +399,27 @@ def create_mcp_server(transport: str, host: str = "localhost", port: Optional[in
387
 
388
  Raises:
389
  ValueError: If transport type is invalid
 
 
 
 
 
390
  """
391
  if transport == "http":
392
  port = port if port is not None else 8080
393
  logger.info(f"Configuring HTTP server on {host}:{port}")
394
  mcp = FastMCP("weave-mcp-server", host=host, port=port, stateless_http=True)
 
 
 
 
 
 
 
395
  elif transport == "stdio":
396
  logger.info("Configuring stdio server")
397
  mcp = FastMCP("weave-mcp-server")
 
398
  else:
399
  raise ValueError(f"Invalid transport type: {transport}. Must be 'stdio' or 'http'")
400
 
@@ -422,11 +447,12 @@ def cli():
422
  --wandb-api-key KEY W&B API key (can also use env var)
423
 
424
  Environment Variables:
425
- WANDB_API_KEY W&B API key for authentication
426
  MCP_SERVER_LOG_LEVEL Server log level (DEBUG, INFO, WARNING, ERROR)
427
  WANDB_SILENT Set to "False" to enable W&B output (default: True)
428
  WEAVE_SILENT Set to "False" to enable Weave output (default: True)
429
  WANDB_DEBUG Set to "true" to enable W&B debug logging
 
430
  """
431
  # Parse command line arguments
432
  import simple_parsing
@@ -438,8 +464,9 @@ def cli():
438
  # Validate and get API key
439
  api_key = validate_and_get_api_key(args)
440
 
441
- # Perform W&B login
442
- setup_wandb_login(api_key)
 
443
 
444
  # Initialize Weave tracing for MCP tool calls
445
  weave_initialized = initialize_weave_tracing()
 
119
  sys.stderr = original_stderr
120
 
121
 
122
+ def validate_and_get_api_key(args: ServerMCPArgs) -> Optional[str]:
123
  """
124
  Validate and retrieve the W&B API key from various sources.
125
 
126
+ For HTTP transport: API key is optional (clients provide their own)
127
+ For STDIO transport: API key is required from environment
128
+
129
  Priority order:
130
  1. Command-line argument (--wandb-api-key)
131
  2. Environment variable (WANDB_API_KEY)
 
136
  args: Parsed command-line arguments
137
 
138
  Returns:
139
+ The W&B API key if found, None otherwise
140
 
141
  Raises:
142
+ ValueError: If no API key is found for STDIO transport
143
  """
144
  api_key = args.wandb_api_key or get_server_args().wandb_api_key
145
 
146
+ # For HTTP transport, API key is optional (clients provide their own)
147
+ if args.transport == "http":
148
+ if api_key:
149
+ logger.info("Server W&B API key configured (for server operations)")
150
+ else:
151
+ logger.info("No server W&B API key configured (clients will provide their own)")
152
+ return api_key
153
+
154
+ # For STDIO transport, API key is required
155
  if not api_key:
156
  raise ValueError(
157
+ "WANDB_API_KEY must be set for STDIO transport. Options:\n"
158
  "1. Command-line: --wandb-api-key YOUR_KEY\n"
159
  "2. Environment: export WANDB_API_KEY=YOUR_KEY\n"
160
  "3. .env file: WANDB_API_KEY=YOUR_KEY\n"
 
399
 
400
  Raises:
401
  ValueError: If transport type is invalid
402
+
403
+ Authentication:
404
+ - STDIO transport: Uses environment variables (WANDB_API_KEY required)
405
+ - HTTP transport: Clients provide W&B API key as Bearer token
406
+ Set MCP_AUTH_DISABLED=true to disable auth (development only)
407
  """
408
  if transport == "http":
409
  port = port if port is not None else 8080
410
  logger.info(f"Configuring HTTP server on {host}:{port}")
411
  mcp = FastMCP("weave-mcp-server", host=host, port=port, stateless_http=True)
412
+
413
+ # Log authentication status for HTTP
414
+ if os.environ.get("MCP_AUTH_DISABLED", "false").lower() == "true":
415
+ logger.warning("⚠️ MCP authentication is DISABLED - server is publicly accessible")
416
+ else:
417
+ logger.info("🔒 MCP authentication enabled - clients must provide W&B API key as Bearer token")
418
+
419
  elif transport == "stdio":
420
  logger.info("Configuring stdio server")
421
  mcp = FastMCP("weave-mcp-server")
422
+ logger.info("STDIO transport uses environment variable authentication")
423
  else:
424
  raise ValueError(f"Invalid transport type: {transport}. Must be 'stdio' or 'http'")
425
 
 
447
  --wandb-api-key KEY W&B API key (can also use env var)
448
 
449
  Environment Variables:
450
+ WANDB_API_KEY W&B API key (required for STDIO, optional for HTTP)
451
  MCP_SERVER_LOG_LEVEL Server log level (DEBUG, INFO, WARNING, ERROR)
452
  WANDB_SILENT Set to "False" to enable W&B output (default: True)
453
  WEAVE_SILENT Set to "False" to enable Weave output (default: True)
454
  WANDB_DEBUG Set to "true" to enable W&B debug logging
455
+ MCP_AUTH_DISABLED Set to "true" to disable HTTP auth (dev only)
456
  """
457
  # Parse command line arguments
458
  import simple_parsing
 
464
  # Validate and get API key
465
  api_key = validate_and_get_api_key(args)
466
 
467
+ # Perform W&B login only if we have an API key
468
+ if api_key:
469
+ setup_wandb_login(api_key)
470
 
471
  # Initialize Weave tracing for MCP tool calls
472
  weave_initialized = initialize_weave_tracing()