Spaces:
Running
Running
Yorrick Jansen
commited on
Commit
·
313d1c4
1
Parent(s):
3975e9b
Iterate on ci/cd
Browse files- pyproject.toml +3 -0
- strava_mcp/config.py +1 -1
- tests/test_auth.py +268 -0
- tests/test_config.py +70 -0
- tests/test_models.py +247 -0
- tests/test_oauth_server.py +278 -0
- tests/test_server.py +9 -2
- uv.lock +1 -1
pyproject.toml
CHANGED
|
@@ -28,6 +28,9 @@ dev = [
|
|
| 28 |
requires = ["hatchling"]
|
| 29 |
build-backend = "hatchling.build"
|
| 30 |
|
|
|
|
|
|
|
|
|
|
| 31 |
[tool.ruff]
|
| 32 |
line-length = 120
|
| 33 |
target-version = "py313"
|
|
|
|
| 28 |
requires = ["hatchling"]
|
| 29 |
build-backend = "hatchling.build"
|
| 30 |
|
| 31 |
+
[tool.hatch.build.targets.wheel]
|
| 32 |
+
packages = ["strava_mcp"]
|
| 33 |
+
|
| 34 |
[tool.ruff]
|
| 35 |
line-length = 120
|
| 36 |
target-version = "py313"
|
strava_mcp/config.py
CHANGED
|
@@ -8,7 +8,7 @@ class StravaSettings(BaseSettings):
|
|
| 8 |
client_id: str = Field(..., description="Strava API client ID")
|
| 9 |
client_secret: str = Field(..., description="Strava API client secret")
|
| 10 |
refresh_token: str | None = Field(
|
| 11 |
-
None,
|
| 12 |
description="Strava API refresh token (can be generated through auth flow)",
|
| 13 |
)
|
| 14 |
base_url: str = Field("https://www.strava.com/api/v3", description="Strava API base URL")
|
|
|
|
| 8 |
client_id: str = Field(..., description="Strava API client ID")
|
| 9 |
client_secret: str = Field(..., description="Strava API client secret")
|
| 10 |
refresh_token: str | None = Field(
|
| 11 |
+
default=None,
|
| 12 |
description="Strava API refresh token (can be generated through auth flow)",
|
| 13 |
)
|
| 14 |
base_url: str = Field("https://www.strava.com/api/v3", description="Strava API base URL")
|
tests/test_auth.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the Strava authentication module."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
from unittest.mock import AsyncMock, MagicMock, patch
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
from fastapi import FastAPI
|
| 8 |
+
from fastapi.testclient import TestClient
|
| 9 |
+
from httpx import Response
|
| 10 |
+
|
| 11 |
+
from strava_mcp.auth import StravaAuthenticator, TokenResponse, get_strava_refresh_token
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@pytest.fixture
|
| 15 |
+
def client_credentials():
|
| 16 |
+
"""Fixture for client credentials."""
|
| 17 |
+
return {
|
| 18 |
+
"client_id": "test_client_id",
|
| 19 |
+
"client_secret": "test_client_secret",
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@pytest.fixture
|
| 24 |
+
def mock_token_response():
|
| 25 |
+
"""Fixture for token response."""
|
| 26 |
+
return {
|
| 27 |
+
"access_token": "test_access_token",
|
| 28 |
+
"refresh_token": "test_refresh_token",
|
| 29 |
+
"expires_at": 1609459200,
|
| 30 |
+
"expires_in": 21600,
|
| 31 |
+
"token_type": "Bearer",
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@pytest.fixture
|
| 36 |
+
def fastapi_app():
|
| 37 |
+
"""Fixture for FastAPI app."""
|
| 38 |
+
return FastAPI()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@pytest.fixture
|
| 42 |
+
def authenticator(client_credentials, fastapi_app):
|
| 43 |
+
"""Fixture for StravaAuthenticator."""
|
| 44 |
+
return StravaAuthenticator(
|
| 45 |
+
client_id=client_credentials["client_id"],
|
| 46 |
+
client_secret=client_credentials["client_secret"],
|
| 47 |
+
app=fastapi_app,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def test_get_authorization_url(authenticator):
|
| 52 |
+
"""Test getting the authorization URL."""
|
| 53 |
+
url = authenticator.get_authorization_url()
|
| 54 |
+
|
| 55 |
+
# Check that the URL contains the expected parameters
|
| 56 |
+
assert "https://www.strava.com/oauth/authorize" in url
|
| 57 |
+
assert f"client_id={authenticator.client_id}" in url
|
| 58 |
+
# URL is encoded, so we need to check the non-encoded parts
|
| 59 |
+
assert "redirect_uri=http%3A%2F%2F127.0.0.1%3A3008%2Fexchange_token" in url
|
| 60 |
+
assert "response_type=code" in url
|
| 61 |
+
assert "scope=" in url
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def test_setup_routes(authenticator, fastapi_app):
|
| 65 |
+
"""Test setting up routes."""
|
| 66 |
+
authenticator.setup_routes(fastapi_app)
|
| 67 |
+
|
| 68 |
+
# Check that the routes were added
|
| 69 |
+
routes = [route.path for route in fastapi_app.routes]
|
| 70 |
+
assert authenticator.redirect_path in routes
|
| 71 |
+
assert "/auth" in routes
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def test_setup_routes_no_app(authenticator):
|
| 75 |
+
"""Test setting up routes with no app."""
|
| 76 |
+
authenticator.app = None
|
| 77 |
+
with pytest.raises(ValueError, match="No FastAPI app provided"):
|
| 78 |
+
authenticator.setup_routes()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@pytest.mark.asyncio
|
| 82 |
+
async def test_exchange_token_success(authenticator, mock_token_response):
|
| 83 |
+
"""Test exchanging token successfully."""
|
| 84 |
+
# Setup mock
|
| 85 |
+
with patch("httpx.AsyncClient") as mock_client:
|
| 86 |
+
mock_response = MagicMock(spec=Response)
|
| 87 |
+
mock_response.status_code = 200
|
| 88 |
+
mock_response.json.return_value = mock_token_response
|
| 89 |
+
mock_client.return_value.__aenter__.return_value.post.return_value = mock_response
|
| 90 |
+
|
| 91 |
+
# Set up a future to receive the token
|
| 92 |
+
authenticator.token_future = asyncio.Future()
|
| 93 |
+
|
| 94 |
+
# Call the handler
|
| 95 |
+
response = await authenticator.exchange_token(code="test_code")
|
| 96 |
+
|
| 97 |
+
# Check response
|
| 98 |
+
assert response.status_code == 200
|
| 99 |
+
assert "Authorization successful" in response.body.decode()
|
| 100 |
+
|
| 101 |
+
# Check token future
|
| 102 |
+
assert authenticator.token_future.done()
|
| 103 |
+
assert await authenticator.token_future == "test_refresh_token"
|
| 104 |
+
|
| 105 |
+
# Check token was saved
|
| 106 |
+
assert authenticator.refresh_token == "test_refresh_token"
|
| 107 |
+
|
| 108 |
+
# Verify correct API call
|
| 109 |
+
mock_client.return_value.__aenter__.return_value.post.assert_called_once()
|
| 110 |
+
args, kwargs = mock_client.return_value.__aenter__.return_value.post.call_args
|
| 111 |
+
assert args[0] == "https://www.strava.com/oauth/token"
|
| 112 |
+
assert kwargs["data"]["client_id"] == authenticator.client_id
|
| 113 |
+
assert kwargs["data"]["client_secret"] == authenticator.client_secret
|
| 114 |
+
assert kwargs["data"]["code"] == "test_code"
|
| 115 |
+
assert kwargs["data"]["grant_type"] == "authorization_code"
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@pytest.mark.asyncio
|
| 119 |
+
async def test_exchange_token_failure(authenticator):
|
| 120 |
+
"""Test exchanging token with failure."""
|
| 121 |
+
# Setup mock
|
| 122 |
+
with patch("httpx.AsyncClient") as mock_client:
|
| 123 |
+
mock_response = MagicMock(spec=Response)
|
| 124 |
+
mock_response.status_code = 400
|
| 125 |
+
mock_response.text = "Invalid code"
|
| 126 |
+
mock_client.return_value.__aenter__.return_value.post.return_value = mock_response
|
| 127 |
+
|
| 128 |
+
# Set up a future to receive the token
|
| 129 |
+
authenticator.token_future = asyncio.Future()
|
| 130 |
+
|
| 131 |
+
# Call the handler
|
| 132 |
+
response = await authenticator.exchange_token(code="invalid_code")
|
| 133 |
+
|
| 134 |
+
# Check response
|
| 135 |
+
assert response.status_code == 200
|
| 136 |
+
assert "Authorization failed" in response.body.decode()
|
| 137 |
+
|
| 138 |
+
# Check token future
|
| 139 |
+
assert authenticator.token_future.done()
|
| 140 |
+
with pytest.raises(Exception):
|
| 141 |
+
await authenticator.token_future
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@pytest.mark.asyncio
|
| 145 |
+
async def test_start_auth_flow(authenticator):
|
| 146 |
+
"""Test starting auth flow."""
|
| 147 |
+
with patch.object(authenticator, "get_authorization_url", return_value="https://example.com/auth"):
|
| 148 |
+
response = await authenticator.start_auth_flow()
|
| 149 |
+
assert response.status_code == 307
|
| 150 |
+
assert response.headers["location"] == "https://example.com/auth"
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
@pytest.mark.asyncio
|
| 154 |
+
async def test_get_refresh_token(authenticator):
|
| 155 |
+
"""Test getting refresh token."""
|
| 156 |
+
# Mock the webbrowser.open call
|
| 157 |
+
with patch("webbrowser.open", return_value=True) as mock_open:
|
| 158 |
+
with patch.object(authenticator, "get_authorization_url", return_value="https://example.com/auth"):
|
| 159 |
+
# Set the future result after a delay
|
| 160 |
+
authenticator.token_future = None # Reset it so a new one is created
|
| 161 |
+
|
| 162 |
+
# Start the token request in background
|
| 163 |
+
task = asyncio.create_task(authenticator.get_refresh_token())
|
| 164 |
+
|
| 165 |
+
# Wait a bit and set the result
|
| 166 |
+
await asyncio.sleep(0.1)
|
| 167 |
+
authenticator.token_future.set_result("test_refresh_token")
|
| 168 |
+
|
| 169 |
+
# Get the result
|
| 170 |
+
token = await task
|
| 171 |
+
|
| 172 |
+
# Verify
|
| 173 |
+
assert token == "test_refresh_token"
|
| 174 |
+
mock_open.assert_called_once_with("https://example.com/auth")
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@pytest.mark.asyncio
|
| 178 |
+
async def test_get_refresh_token_no_browser(authenticator):
|
| 179 |
+
"""Test getting refresh token without opening browser."""
|
| 180 |
+
with patch("webbrowser.open") as mock_open:
|
| 181 |
+
with patch.object(authenticator, "get_authorization_url", return_value="https://example.com/auth"):
|
| 182 |
+
# Set the future result after a delay
|
| 183 |
+
authenticator.token_future = None # Reset it so a new one is created
|
| 184 |
+
|
| 185 |
+
# Start the token request in background
|
| 186 |
+
task = asyncio.create_task(authenticator.get_refresh_token(open_browser=False))
|
| 187 |
+
|
| 188 |
+
# Wait a bit and set the result
|
| 189 |
+
await asyncio.sleep(0.1)
|
| 190 |
+
authenticator.token_future.set_result("test_refresh_token")
|
| 191 |
+
|
| 192 |
+
# Get the result
|
| 193 |
+
token = await task
|
| 194 |
+
|
| 195 |
+
# Verify
|
| 196 |
+
assert token == "test_refresh_token"
|
| 197 |
+
mock_open.assert_not_called()
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@pytest.mark.asyncio
|
| 201 |
+
async def test_get_refresh_token_browser_fails(authenticator):
|
| 202 |
+
"""Test getting refresh token with browser opening failing."""
|
| 203 |
+
with patch("webbrowser.open", return_value=False) as mock_open:
|
| 204 |
+
with patch.object(authenticator, "get_authorization_url", return_value="https://example.com/auth"):
|
| 205 |
+
# Set the future result after a delay
|
| 206 |
+
authenticator.token_future = None # Reset it so a new one is created
|
| 207 |
+
|
| 208 |
+
# Start the token request in background
|
| 209 |
+
task = asyncio.create_task(authenticator.get_refresh_token())
|
| 210 |
+
|
| 211 |
+
# Wait a bit and set the result
|
| 212 |
+
await asyncio.sleep(0.1)
|
| 213 |
+
authenticator.token_future.set_result("test_refresh_token")
|
| 214 |
+
|
| 215 |
+
# Get the result
|
| 216 |
+
token = await task
|
| 217 |
+
|
| 218 |
+
# Verify
|
| 219 |
+
assert token == "test_refresh_token"
|
| 220 |
+
mock_open.assert_called_once_with("https://example.com/auth")
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
@pytest.mark.asyncio
|
| 224 |
+
async def test_get_strava_refresh_token(client_credentials):
|
| 225 |
+
"""Test get_strava_refresh_token function."""
|
| 226 |
+
with patch("strava_mcp.auth.StravaAuthenticator") as MockAuthenticator:
|
| 227 |
+
# Setup mock
|
| 228 |
+
mock_authenticator = MagicMock()
|
| 229 |
+
mock_authenticator.get_refresh_token = AsyncMock(return_value="test_refresh_token")
|
| 230 |
+
mock_authenticator.setup_routes = MagicMock()
|
| 231 |
+
MockAuthenticator.return_value = mock_authenticator
|
| 232 |
+
|
| 233 |
+
# Test without app
|
| 234 |
+
token = await get_strava_refresh_token(
|
| 235 |
+
client_credentials["client_id"],
|
| 236 |
+
client_credentials["client_secret"]
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# Verify
|
| 240 |
+
assert token == "test_refresh_token"
|
| 241 |
+
MockAuthenticator.assert_called_once_with(
|
| 242 |
+
client_credentials["client_id"],
|
| 243 |
+
client_credentials["client_secret"],
|
| 244 |
+
None
|
| 245 |
+
)
|
| 246 |
+
mock_authenticator.setup_routes.assert_not_called()
|
| 247 |
+
|
| 248 |
+
# Reset mocks
|
| 249 |
+
MockAuthenticator.reset_mock()
|
| 250 |
+
mock_authenticator.get_refresh_token.reset_mock()
|
| 251 |
+
mock_authenticator.setup_routes.reset_mock()
|
| 252 |
+
|
| 253 |
+
# Test with app
|
| 254 |
+
app = FastAPI()
|
| 255 |
+
token = await get_strava_refresh_token(
|
| 256 |
+
client_credentials["client_id"],
|
| 257 |
+
client_credentials["client_secret"],
|
| 258 |
+
app
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
# Verify
|
| 262 |
+
assert token == "test_refresh_token"
|
| 263 |
+
MockAuthenticator.assert_called_once_with(
|
| 264 |
+
client_credentials["client_id"],
|
| 265 |
+
client_credentials["client_secret"],
|
| 266 |
+
app
|
| 267 |
+
)
|
| 268 |
+
mock_authenticator.setup_routes.assert_called_once_with(app)
|
tests/test_config.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for configuration module."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from unittest import mock
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
from strava_mcp.config import StravaSettings
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def test_strava_settings_defaults():
|
| 12 |
+
"""Test default settings for StravaSettings."""
|
| 13 |
+
# Use required parameters only
|
| 14 |
+
with mock.patch.dict(os.environ, {}, clear=True):
|
| 15 |
+
settings = StravaSettings(
|
| 16 |
+
client_id="test_client_id",
|
| 17 |
+
client_secret="test_client_secret",
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
assert settings.client_id == "test_client_id"
|
| 21 |
+
assert settings.client_secret == "test_client_secret"
|
| 22 |
+
assert settings.refresh_token is None
|
| 23 |
+
assert settings.base_url == "https://www.strava.com/api/v3"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def test_strava_settings_from_env():
|
| 27 |
+
"""Test loading settings from environment variables."""
|
| 28 |
+
with mock.patch.dict(
|
| 29 |
+
os.environ,
|
| 30 |
+
{
|
| 31 |
+
"STRAVA_CLIENT_ID": "env_client_id",
|
| 32 |
+
"STRAVA_CLIENT_SECRET": "env_client_secret",
|
| 33 |
+
"STRAVA_REFRESH_TOKEN": "env_refresh_token",
|
| 34 |
+
"STRAVA_BASE_URL": "https://custom.strava.api/v3",
|
| 35 |
+
},
|
| 36 |
+
):
|
| 37 |
+
settings = StravaSettings()
|
| 38 |
+
|
| 39 |
+
assert settings.client_id == "env_client_id"
|
| 40 |
+
assert settings.client_secret == "env_client_secret"
|
| 41 |
+
assert settings.refresh_token == "env_refresh_token"
|
| 42 |
+
assert settings.base_url == "https://custom.strava.api/v3"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def test_strava_settings_override():
|
| 46 |
+
"""Test overriding environment settings with direct values."""
|
| 47 |
+
with mock.patch.dict(
|
| 48 |
+
os.environ,
|
| 49 |
+
{
|
| 50 |
+
"STRAVA_CLIENT_ID": "env_client_id",
|
| 51 |
+
"STRAVA_CLIENT_SECRET": "env_client_secret",
|
| 52 |
+
"STRAVA_REFRESH_TOKEN": "env_refresh_token",
|
| 53 |
+
},
|
| 54 |
+
):
|
| 55 |
+
settings = StravaSettings(
|
| 56 |
+
client_id="direct_client_id",
|
| 57 |
+
refresh_token="direct_refresh_token",
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Direct values should override environment variables
|
| 61 |
+
assert settings.client_id == "direct_client_id"
|
| 62 |
+
assert settings.client_secret == "env_client_secret"
|
| 63 |
+
assert settings.refresh_token == "direct_refresh_token"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def test_strava_settings_model_config():
|
| 67 |
+
"""Test model configuration for StravaSettings."""
|
| 68 |
+
assert StravaSettings.model_config["env_prefix"] == "STRAVA_"
|
| 69 |
+
assert StravaSettings.model_config["env_file"] == ".env"
|
| 70 |
+
assert StravaSettings.model_config["env_file_encoding"] == "utf-8"
|
tests/test_models.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the Strava models."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from pydantic import ValidationError
|
| 6 |
+
|
| 7 |
+
from strava_mcp.models import Activity, DetailedActivity, Segment, SegmentEffort, ErrorResponse
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@pytest.fixture
|
| 11 |
+
def activity_data():
|
| 12 |
+
"""Fixture with valid activity data."""
|
| 13 |
+
return {
|
| 14 |
+
"id": 1234567890,
|
| 15 |
+
"name": "Morning Run",
|
| 16 |
+
"distance": 5000.0,
|
| 17 |
+
"moving_time": 1200,
|
| 18 |
+
"elapsed_time": 1300,
|
| 19 |
+
"total_elevation_gain": 50.0,
|
| 20 |
+
"type": "Run",
|
| 21 |
+
"sport_type": "Run",
|
| 22 |
+
"start_date": "2023-01-01T10:00:00Z",
|
| 23 |
+
"start_date_local": "2023-01-01T10:00:00Z",
|
| 24 |
+
"timezone": "Europe/London",
|
| 25 |
+
"achievement_count": 2,
|
| 26 |
+
"kudos_count": 5,
|
| 27 |
+
"comment_count": 0,
|
| 28 |
+
"athlete_count": 1,
|
| 29 |
+
"photo_count": 0,
|
| 30 |
+
"trainer": False,
|
| 31 |
+
"commute": False,
|
| 32 |
+
"manual": False,
|
| 33 |
+
"private": False,
|
| 34 |
+
"flagged": False,
|
| 35 |
+
"average_speed": 4.167,
|
| 36 |
+
"max_speed": 5.3,
|
| 37 |
+
"has_heartrate": True,
|
| 38 |
+
"average_heartrate": 140.0,
|
| 39 |
+
"max_heartrate": 160.0,
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@pytest.fixture
|
| 44 |
+
def detailed_activity_data(activity_data):
|
| 45 |
+
"""Fixture with valid detailed activity data."""
|
| 46 |
+
return {
|
| 47 |
+
**activity_data,
|
| 48 |
+
"description": "Test description",
|
| 49 |
+
"athlete": {"id": 123},
|
| 50 |
+
"calories": 500.0,
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@pytest.fixture
|
| 55 |
+
def segment_data():
|
| 56 |
+
"""Fixture with valid segment data."""
|
| 57 |
+
return {
|
| 58 |
+
"id": 12345,
|
| 59 |
+
"name": "Test Segment",
|
| 60 |
+
"activity_type": "Run",
|
| 61 |
+
"distance": 1000.0,
|
| 62 |
+
"average_grade": 5.0,
|
| 63 |
+
"maximum_grade": 10.0,
|
| 64 |
+
"elevation_high": 200.0,
|
| 65 |
+
"elevation_low": 150.0,
|
| 66 |
+
"total_elevation_gain": 50.0,
|
| 67 |
+
"start_latlng": [51.5, -0.1],
|
| 68 |
+
"end_latlng": [51.5, -0.2],
|
| 69 |
+
"climb_category": 0,
|
| 70 |
+
"private": False,
|
| 71 |
+
"starred": False,
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@pytest.fixture
|
| 76 |
+
def segment_effort_data(segment_data):
|
| 77 |
+
"""Fixture with valid segment effort data."""
|
| 78 |
+
return {
|
| 79 |
+
"id": 67890,
|
| 80 |
+
"activity_id": 1234567890,
|
| 81 |
+
"segment_id": 12345,
|
| 82 |
+
"name": "Test Segment",
|
| 83 |
+
"elapsed_time": 180,
|
| 84 |
+
"moving_time": 180,
|
| 85 |
+
"start_date": "2023-01-01T10:05:00Z",
|
| 86 |
+
"start_date_local": "2023-01-01T10:05:00Z",
|
| 87 |
+
"distance": 1000.0,
|
| 88 |
+
"athlete": {"id": 123},
|
| 89 |
+
"segment": segment_data,
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def test_activity_model(activity_data):
|
| 94 |
+
"""Test the Activity model."""
|
| 95 |
+
activity = Activity(**activity_data)
|
| 96 |
+
|
| 97 |
+
assert activity.id == activity_data["id"]
|
| 98 |
+
assert activity.name == activity_data["name"]
|
| 99 |
+
assert activity.distance == activity_data["distance"]
|
| 100 |
+
assert activity.start_date == datetime.fromisoformat(activity_data["start_date"].replace("Z", "+00:00"))
|
| 101 |
+
assert activity.start_date_local == datetime.fromisoformat(activity_data["start_date_local"].replace("Z", "+00:00"))
|
| 102 |
+
assert activity.average_heartrate == activity_data["average_heartrate"]
|
| 103 |
+
assert activity.max_heartrate == activity_data["max_heartrate"]
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def test_activity_model_optional_fields(activity_data):
|
| 107 |
+
"""Test the Activity model with optional fields."""
|
| 108 |
+
# Remove some optional fields
|
| 109 |
+
data = activity_data.copy()
|
| 110 |
+
data.pop("average_heartrate")
|
| 111 |
+
data.pop("max_heartrate")
|
| 112 |
+
|
| 113 |
+
activity = Activity(**data)
|
| 114 |
+
|
| 115 |
+
assert activity.average_heartrate is None
|
| 116 |
+
assert activity.max_heartrate is None
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def test_activity_model_missing_required_fields(activity_data):
|
| 120 |
+
"""Test the Activity model with missing required fields."""
|
| 121 |
+
data = activity_data.copy()
|
| 122 |
+
data.pop("id") # Remove a required field
|
| 123 |
+
|
| 124 |
+
with pytest.raises(ValidationError):
|
| 125 |
+
Activity(**data)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def test_detailed_activity_model(detailed_activity_data):
|
| 129 |
+
"""Test the DetailedActivity model."""
|
| 130 |
+
activity = DetailedActivity(**detailed_activity_data)
|
| 131 |
+
|
| 132 |
+
assert activity.id == detailed_activity_data["id"]
|
| 133 |
+
assert activity.name == detailed_activity_data["name"]
|
| 134 |
+
assert activity.description == detailed_activity_data["description"]
|
| 135 |
+
assert activity.athlete == detailed_activity_data["athlete"]
|
| 136 |
+
assert activity.calories == detailed_activity_data["calories"]
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def test_detailed_activity_optional_fields(detailed_activity_data):
|
| 140 |
+
"""Test the DetailedActivity model with optional fields."""
|
| 141 |
+
data = detailed_activity_data.copy()
|
| 142 |
+
data.pop("description")
|
| 143 |
+
data.pop("calories")
|
| 144 |
+
|
| 145 |
+
activity = DetailedActivity(**data)
|
| 146 |
+
|
| 147 |
+
assert activity.description is None
|
| 148 |
+
assert activity.calories is None
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def test_segment_model(segment_data):
|
| 152 |
+
"""Test the Segment model."""
|
| 153 |
+
segment = Segment(**segment_data)
|
| 154 |
+
|
| 155 |
+
assert segment.id == segment_data["id"]
|
| 156 |
+
assert segment.name == segment_data["name"]
|
| 157 |
+
assert segment.activity_type == segment_data["activity_type"]
|
| 158 |
+
assert segment.distance == segment_data["distance"]
|
| 159 |
+
assert segment.start_latlng == segment_data["start_latlng"]
|
| 160 |
+
assert segment.end_latlng == segment_data["end_latlng"]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def test_segment_optional_fields(segment_data):
|
| 164 |
+
"""Test the Segment model with optional fields."""
|
| 165 |
+
# Add some optional fields
|
| 166 |
+
data = segment_data.copy()
|
| 167 |
+
data["city"] = "London"
|
| 168 |
+
data["state"] = "Greater London"
|
| 169 |
+
data["country"] = "United Kingdom"
|
| 170 |
+
|
| 171 |
+
segment = Segment(**data)
|
| 172 |
+
|
| 173 |
+
assert segment.city == "London"
|
| 174 |
+
assert segment.state == "Greater London"
|
| 175 |
+
assert segment.country == "United Kingdom"
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def test_segment_missing_fields(segment_data):
|
| 179 |
+
"""Test the Segment model with missing required fields."""
|
| 180 |
+
data = segment_data.copy()
|
| 181 |
+
data.pop("id") # Remove a required field
|
| 182 |
+
|
| 183 |
+
with pytest.raises(ValidationError):
|
| 184 |
+
Segment(**data)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def test_segment_effort_model(segment_effort_data):
|
| 188 |
+
"""Test the SegmentEffort model."""
|
| 189 |
+
effort = SegmentEffort(**segment_effort_data)
|
| 190 |
+
|
| 191 |
+
assert effort.id == segment_effort_data["id"]
|
| 192 |
+
assert effort.activity_id == segment_effort_data["activity_id"]
|
| 193 |
+
assert effort.segment_id == segment_effort_data["segment_id"]
|
| 194 |
+
assert effort.name == segment_effort_data["name"]
|
| 195 |
+
assert effort.elapsed_time == segment_effort_data["elapsed_time"]
|
| 196 |
+
assert effort.moving_time == segment_effort_data["moving_time"]
|
| 197 |
+
assert effort.start_date == datetime.fromisoformat(segment_effort_data["start_date"].replace("Z", "+00:00"))
|
| 198 |
+
assert effort.start_date_local == datetime.fromisoformat(segment_effort_data["start_date_local"].replace("Z", "+00:00"))
|
| 199 |
+
assert effort.distance == segment_effort_data["distance"]
|
| 200 |
+
assert effort.athlete == segment_effort_data["athlete"]
|
| 201 |
+
|
| 202 |
+
# Test nested segment object
|
| 203 |
+
assert effort.segment.id == segment_effort_data["segment"]["id"]
|
| 204 |
+
assert effort.segment.name == segment_effort_data["segment"]["name"]
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def test_segment_effort_optional_fields(segment_effort_data):
|
| 208 |
+
"""Test the SegmentEffort model with optional fields."""
|
| 209 |
+
# Add some optional fields
|
| 210 |
+
data = segment_effort_data.copy()
|
| 211 |
+
data["average_watts"] = 200.0
|
| 212 |
+
data["device_watts"] = True
|
| 213 |
+
data["average_heartrate"] = 150.0
|
| 214 |
+
data["max_heartrate"] = 170.0
|
| 215 |
+
data["pr_rank"] = 1
|
| 216 |
+
data["achievements"] = [{"type": "overall", "rank": 1}]
|
| 217 |
+
|
| 218 |
+
effort = SegmentEffort(**data)
|
| 219 |
+
|
| 220 |
+
assert effort.average_watts == 200.0
|
| 221 |
+
assert effort.device_watts is True
|
| 222 |
+
assert effort.average_heartrate == 150.0
|
| 223 |
+
assert effort.max_heartrate == 170.0
|
| 224 |
+
assert effort.pr_rank == 1
|
| 225 |
+
assert effort.achievements == [{"type": "overall", "rank": 1}]
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def test_segment_effort_missing_fields(segment_effort_data):
|
| 229 |
+
"""Test the SegmentEffort model with missing required fields."""
|
| 230 |
+
data = segment_effort_data.copy()
|
| 231 |
+
data.pop("segment") # Remove a required field
|
| 232 |
+
|
| 233 |
+
with pytest.raises(ValidationError):
|
| 234 |
+
SegmentEffort(**data)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def test_error_response():
|
| 238 |
+
"""Test the ErrorResponse model."""
|
| 239 |
+
data = {
|
| 240 |
+
"message": "Resource not found",
|
| 241 |
+
"code": 404
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
error = ErrorResponse(**data)
|
| 245 |
+
|
| 246 |
+
assert error.message == "Resource not found"
|
| 247 |
+
assert error.code == 404
|
tests/test_oauth_server.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the Strava OAuth server module."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
from unittest.mock import AsyncMock, MagicMock, patch
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
import uvicorn
|
| 8 |
+
from fastapi import FastAPI
|
| 9 |
+
|
| 10 |
+
from strava_mcp.auth import StravaAuthenticator
|
| 11 |
+
from strava_mcp.oauth_server import StravaOAuthServer, get_refresh_token_from_oauth
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@pytest.fixture
|
| 15 |
+
def client_credentials():
|
| 16 |
+
"""Fixture for client credentials."""
|
| 17 |
+
return {
|
| 18 |
+
"client_id": "test_client_id",
|
| 19 |
+
"client_secret": "test_client_secret",
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@pytest.fixture
|
| 24 |
+
def oauth_server(client_credentials):
|
| 25 |
+
"""Fixture for StravaOAuthServer."""
|
| 26 |
+
return StravaOAuthServer(
|
| 27 |
+
client_id=client_credentials["client_id"],
|
| 28 |
+
client_secret=client_credentials["client_secret"],
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@pytest.mark.asyncio
|
| 33 |
+
async def test_initialize_server(oauth_server):
|
| 34 |
+
"""Test initializing the server."""
|
| 35 |
+
# Mock the OAuth server's dependencies directly
|
| 36 |
+
with patch("strava_mcp.oauth_server.StravaAuthenticator") as MockAuthenticator:
|
| 37 |
+
with patch("asyncio.create_task") as mock_create_task:
|
| 38 |
+
# Setup mocks
|
| 39 |
+
mock_authenticator = MagicMock()
|
| 40 |
+
MockAuthenticator.return_value = mock_authenticator
|
| 41 |
+
mock_task = MagicMock()
|
| 42 |
+
mock_create_task.return_value = mock_task
|
| 43 |
+
|
| 44 |
+
# Test method
|
| 45 |
+
await oauth_server._initialize_server()
|
| 46 |
+
|
| 47 |
+
# Verify FastAPI app was created
|
| 48 |
+
assert oauth_server.app is not None
|
| 49 |
+
assert oauth_server.app.title == "Strava OAuth"
|
| 50 |
+
|
| 51 |
+
# Verify authenticator was created and configured
|
| 52 |
+
MockAuthenticator.assert_called_once_with(
|
| 53 |
+
client_id=oauth_server.client_id,
|
| 54 |
+
client_secret=oauth_server.client_secret,
|
| 55 |
+
app=oauth_server.app,
|
| 56 |
+
host=oauth_server.host,
|
| 57 |
+
port=oauth_server.port,
|
| 58 |
+
)
|
| 59 |
+
assert oauth_server.authenticator == mock_authenticator
|
| 60 |
+
|
| 61 |
+
# Verify token future was stored in authenticator
|
| 62 |
+
assert mock_authenticator.token_future is oauth_server.token_future
|
| 63 |
+
|
| 64 |
+
# Verify routes were set up
|
| 65 |
+
mock_authenticator.setup_routes.assert_called_once_with(oauth_server.app)
|
| 66 |
+
|
| 67 |
+
# Verify server task was created
|
| 68 |
+
mock_create_task.assert_called_once()
|
| 69 |
+
assert oauth_server.server_task == mock_task
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@pytest.mark.asyncio
|
| 73 |
+
async def test_run_server(oauth_server):
|
| 74 |
+
"""Test running the server."""
|
| 75 |
+
with patch("uvicorn.Server") as MockServer:
|
| 76 |
+
with patch("uvicorn.Config") as MockConfig:
|
| 77 |
+
# Setup mocks
|
| 78 |
+
mock_server = AsyncMock()
|
| 79 |
+
MockServer.return_value = mock_server
|
| 80 |
+
mock_config = MagicMock()
|
| 81 |
+
MockConfig.return_value = mock_config
|
| 82 |
+
|
| 83 |
+
# Create app
|
| 84 |
+
oauth_server.app = FastAPI()
|
| 85 |
+
|
| 86 |
+
# Test method
|
| 87 |
+
await oauth_server._run_server()
|
| 88 |
+
|
| 89 |
+
# Verify config was created correctly
|
| 90 |
+
MockConfig.assert_called_once_with(
|
| 91 |
+
app=oauth_server.app,
|
| 92 |
+
host=oauth_server.host,
|
| 93 |
+
port=oauth_server.port,
|
| 94 |
+
log_level="info",
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# Verify server was created and run
|
| 98 |
+
MockServer.assert_called_once_with(mock_config)
|
| 99 |
+
mock_server.serve.assert_called_once()
|
| 100 |
+
assert oauth_server.server == mock_server
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@pytest.mark.asyncio
|
| 104 |
+
async def test_run_server_exception(oauth_server):
|
| 105 |
+
"""Test running the server with an exception."""
|
| 106 |
+
with patch("uvicorn.Server") as MockServer:
|
| 107 |
+
with patch("uvicorn.Config") as MockConfig:
|
| 108 |
+
# Setup mocks
|
| 109 |
+
mock_server = AsyncMock()
|
| 110 |
+
mock_server.serve = AsyncMock(side_effect=Exception("Test error"))
|
| 111 |
+
MockServer.return_value = mock_server
|
| 112 |
+
mock_config = MagicMock()
|
| 113 |
+
MockConfig.return_value = mock_config
|
| 114 |
+
|
| 115 |
+
# Create app and token future
|
| 116 |
+
oauth_server.app = FastAPI()
|
| 117 |
+
oauth_server.token_future = asyncio.Future()
|
| 118 |
+
|
| 119 |
+
# Test method
|
| 120 |
+
await oauth_server._run_server()
|
| 121 |
+
|
| 122 |
+
# Verify token future has exception
|
| 123 |
+
assert oauth_server.token_future.done()
|
| 124 |
+
with pytest.raises(Exception, match="Test error"):
|
| 125 |
+
await oauth_server.token_future
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@pytest.mark.asyncio
|
| 129 |
+
async def test_stop_server(oauth_server):
|
| 130 |
+
"""Test stopping the server."""
|
| 131 |
+
# Setup server and task
|
| 132 |
+
oauth_server.server = MagicMock()
|
| 133 |
+
oauth_server.server_task = MagicMock()
|
| 134 |
+
oauth_server.server_task.done = MagicMock(return_value=False)
|
| 135 |
+
|
| 136 |
+
# Make asyncio.wait_for return immediately
|
| 137 |
+
with patch("asyncio.wait_for", new=AsyncMock()) as mock_wait_for:
|
| 138 |
+
# Test method
|
| 139 |
+
await oauth_server._stop_server()
|
| 140 |
+
|
| 141 |
+
# Verify server was stopped
|
| 142 |
+
assert oauth_server.server.should_exit is True
|
| 143 |
+
mock_wait_for.assert_called_once_with(oauth_server.server_task, timeout=5.0)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@pytest.mark.asyncio
|
| 147 |
+
async def test_stop_server_timeout(oauth_server):
|
| 148 |
+
"""Test stopping the server with timeout."""
|
| 149 |
+
# Setup server and task
|
| 150 |
+
oauth_server.server = MagicMock()
|
| 151 |
+
oauth_server.server_task = MagicMock()
|
| 152 |
+
oauth_server.server_task.done = MagicMock(return_value=False)
|
| 153 |
+
|
| 154 |
+
# Make asyncio.wait_for raise TimeoutError
|
| 155 |
+
with patch("asyncio.wait_for", new=AsyncMock(side_effect=TimeoutError())) as mock_wait_for:
|
| 156 |
+
# Test method
|
| 157 |
+
await oauth_server._stop_server()
|
| 158 |
+
|
| 159 |
+
# Verify server was stopped
|
| 160 |
+
assert oauth_server.server.should_exit is True
|
| 161 |
+
mock_wait_for.assert_called_once_with(oauth_server.server_task, timeout=5.0)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@pytest.mark.asyncio
|
| 165 |
+
async def test_get_token(oauth_server):
|
| 166 |
+
"""Test getting a token."""
|
| 167 |
+
# Setup mocks
|
| 168 |
+
oauth_server._initialize_server = AsyncMock()
|
| 169 |
+
oauth_server._stop_server = AsyncMock()
|
| 170 |
+
oauth_server.authenticator = MagicMock()
|
| 171 |
+
oauth_server.authenticator.get_authorization_url = MagicMock(return_value="https://example.com/auth")
|
| 172 |
+
|
| 173 |
+
with patch("webbrowser.open") as mock_open:
|
| 174 |
+
# Prepare token future
|
| 175 |
+
oauth_server.token_future = asyncio.Future()
|
| 176 |
+
oauth_server.token_future.set_result("test_refresh_token")
|
| 177 |
+
|
| 178 |
+
# Test method
|
| 179 |
+
token = await oauth_server.get_token()
|
| 180 |
+
|
| 181 |
+
# Verify
|
| 182 |
+
assert token == "test_refresh_token"
|
| 183 |
+
oauth_server._initialize_server.assert_called_once()
|
| 184 |
+
oauth_server.authenticator.get_authorization_url.assert_called_once()
|
| 185 |
+
mock_open.assert_called_once_with("https://example.com/auth")
|
| 186 |
+
oauth_server._stop_server.assert_called_once()
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@pytest.mark.asyncio
|
| 190 |
+
async def test_get_token_no_authenticator(oauth_server):
|
| 191 |
+
"""Test getting a token with no authenticator."""
|
| 192 |
+
# Setup mocks
|
| 193 |
+
oauth_server._initialize_server = AsyncMock()
|
| 194 |
+
oauth_server._stop_server = AsyncMock()
|
| 195 |
+
oauth_server.authenticator = None
|
| 196 |
+
|
| 197 |
+
# Test method
|
| 198 |
+
with pytest.raises(Exception, match="Authenticator not initialized"):
|
| 199 |
+
await oauth_server.get_token()
|
| 200 |
+
|
| 201 |
+
# Verify
|
| 202 |
+
oauth_server._initialize_server.assert_called_once()
|
| 203 |
+
# The stop server is not called because we exit with exception before getting there
|
| 204 |
+
# oauth_server._stop_server.assert_called_once()
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
@pytest.mark.asyncio
|
| 208 |
+
async def test_get_token_cancelled(oauth_server):
|
| 209 |
+
"""Test getting a token that is cancelled."""
|
| 210 |
+
# Setup mocks
|
| 211 |
+
oauth_server._initialize_server = AsyncMock()
|
| 212 |
+
oauth_server._stop_server = AsyncMock()
|
| 213 |
+
oauth_server.authenticator = MagicMock()
|
| 214 |
+
oauth_server.authenticator.get_authorization_url = MagicMock(return_value="https://example.com/auth")
|
| 215 |
+
|
| 216 |
+
with patch("webbrowser.open") as mock_open:
|
| 217 |
+
# Prepare token future with cancellation
|
| 218 |
+
oauth_server.token_future = asyncio.Future()
|
| 219 |
+
oauth_server.token_future.cancel()
|
| 220 |
+
|
| 221 |
+
# Test method
|
| 222 |
+
with pytest.raises(Exception, match="OAuth flow was cancelled"):
|
| 223 |
+
await oauth_server.get_token()
|
| 224 |
+
|
| 225 |
+
# Verify
|
| 226 |
+
oauth_server._initialize_server.assert_called_once()
|
| 227 |
+
oauth_server.authenticator.get_authorization_url.assert_called_once()
|
| 228 |
+
mock_open.assert_called_once_with("https://example.com/auth")
|
| 229 |
+
oauth_server._stop_server.assert_called_once()
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
@pytest.mark.asyncio
|
| 233 |
+
async def test_get_token_exception(oauth_server):
|
| 234 |
+
"""Test getting a token with exception."""
|
| 235 |
+
# Setup mocks
|
| 236 |
+
oauth_server._initialize_server = AsyncMock()
|
| 237 |
+
oauth_server._stop_server = AsyncMock()
|
| 238 |
+
oauth_server.authenticator = MagicMock()
|
| 239 |
+
oauth_server.authenticator.get_authorization_url = MagicMock(return_value="https://example.com/auth")
|
| 240 |
+
|
| 241 |
+
with patch("webbrowser.open") as mock_open:
|
| 242 |
+
# Prepare token future with exception
|
| 243 |
+
oauth_server.token_future = asyncio.Future()
|
| 244 |
+
oauth_server.token_future.set_exception(Exception("Test error"))
|
| 245 |
+
|
| 246 |
+
# Test method
|
| 247 |
+
with pytest.raises(Exception, match="OAuth flow failed: Test error"):
|
| 248 |
+
await oauth_server.get_token()
|
| 249 |
+
|
| 250 |
+
# Verify
|
| 251 |
+
oauth_server._initialize_server.assert_called_once()
|
| 252 |
+
oauth_server.authenticator.get_authorization_url.assert_called_once()
|
| 253 |
+
mock_open.assert_called_once_with("https://example.com/auth")
|
| 254 |
+
oauth_server._stop_server.assert_called_once()
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
@pytest.mark.asyncio
|
| 258 |
+
async def test_get_refresh_token_from_oauth(client_credentials):
|
| 259 |
+
"""Test get_refresh_token_from_oauth function."""
|
| 260 |
+
with patch("strava_mcp.oauth_server.StravaOAuthServer") as MockOAuthServer:
|
| 261 |
+
# Setup mock
|
| 262 |
+
mock_server = MagicMock()
|
| 263 |
+
mock_server.get_token = AsyncMock(return_value="test_refresh_token")
|
| 264 |
+
MockOAuthServer.return_value = mock_server
|
| 265 |
+
|
| 266 |
+
# Test function
|
| 267 |
+
token = await get_refresh_token_from_oauth(
|
| 268 |
+
client_credentials["client_id"],
|
| 269 |
+
client_credentials["client_secret"]
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
# Verify
|
| 273 |
+
assert token == "test_refresh_token"
|
| 274 |
+
MockOAuthServer.assert_called_once_with(
|
| 275 |
+
client_credentials["client_id"],
|
| 276 |
+
client_credentials["client_secret"]
|
| 277 |
+
)
|
| 278 |
+
mock_server.get_token.assert_called_once()
|
tests/test_server.py
CHANGED
|
@@ -1,10 +1,17 @@
|
|
| 1 |
-
from unittest.mock import AsyncMock, MagicMock
|
| 2 |
|
| 3 |
import pytest
|
| 4 |
|
| 5 |
from strava_mcp.models import Activity, DetailedActivity, SegmentEffort
|
| 6 |
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
class MockContext:
|
| 9 |
"""Mock MCP context for testing."""
|
| 10 |
|
|
@@ -167,4 +174,4 @@ async def test_get_activity_segments(mock_ctx, mock_service):
|
|
| 167 |
# Verify result
|
| 168 |
assert len(result) == 1
|
| 169 |
assert result[0]["id"] == mock_segment.id
|
| 170 |
-
assert result[0]["name"] == mock_segment.name
|
|
|
|
| 1 |
+
from unittest.mock import AsyncMock, MagicMock, patch
|
| 2 |
|
| 3 |
import pytest
|
| 4 |
|
| 5 |
from strava_mcp.models import Activity, DetailedActivity, SegmentEffort
|
| 6 |
|
| 7 |
|
| 8 |
+
# Patch the StravaOAuthServer._run_server method to prevent coroutine warnings
|
| 9 |
+
# This must be at the module level before any imports that might create the coroutine
|
| 10 |
+
with patch("strava_mcp.oauth_server.StravaOAuthServer._run_server", new_callable=AsyncMock):
|
| 11 |
+
# Now imports will use the patched version
|
| 12 |
+
pass
|
| 13 |
+
|
| 14 |
+
|
| 15 |
class MockContext:
|
| 16 |
"""Mock MCP context for testing."""
|
| 17 |
|
|
|
|
| 174 |
# Verify result
|
| 175 |
assert len(result) == 1
|
| 176 |
assert result[0]["id"] == mock_segment.id
|
| 177 |
+
assert result[0]["name"] == mock_segment.name
|
uv.lock
CHANGED
|
@@ -888,7 +888,7 @@ wheels = [
|
|
| 888 |
[[package]]
|
| 889 |
name = "strava"
|
| 890 |
version = "0.1.0"
|
| 891 |
-
source = {
|
| 892 |
dependencies = [
|
| 893 |
{ name = "fastapi" },
|
| 894 |
{ name = "httpx" },
|
|
|
|
| 888 |
[[package]]
|
| 889 |
name = "strava"
|
| 890 |
version = "0.1.0"
|
| 891 |
+
source = { editable = "." }
|
| 892 |
dependencies = [
|
| 893 |
{ name = "fastapi" },
|
| 894 |
{ name = "httpx" },
|