Yorrick Jansen commited on
Commit
313d1c4
·
1 Parent(s): 3975e9b

Iterate on ci/cd

Browse files
pyproject.toml CHANGED
@@ -28,6 +28,9 @@ dev = [
28
  requires = ["hatchling"]
29
  build-backend = "hatchling.build"
30
 
 
 
 
31
  [tool.ruff]
32
  line-length = 120
33
  target-version = "py313"
 
28
  requires = ["hatchling"]
29
  build-backend = "hatchling.build"
30
 
31
+ [tool.hatch.build.targets.wheel]
32
+ packages = ["strava_mcp"]
33
+
34
  [tool.ruff]
35
  line-length = 120
36
  target-version = "py313"
strava_mcp/config.py CHANGED
@@ -8,7 +8,7 @@ class StravaSettings(BaseSettings):
8
  client_id: str = Field(..., description="Strava API client ID")
9
  client_secret: str = Field(..., description="Strava API client secret")
10
  refresh_token: str | None = Field(
11
- None,
12
  description="Strava API refresh token (can be generated through auth flow)",
13
  )
14
  base_url: str = Field("https://www.strava.com/api/v3", description="Strava API base URL")
 
8
  client_id: str = Field(..., description="Strava API client ID")
9
  client_secret: str = Field(..., description="Strava API client secret")
10
  refresh_token: str | None = Field(
11
+ default=None,
12
  description="Strava API refresh token (can be generated through auth flow)",
13
  )
14
  base_url: str = Field("https://www.strava.com/api/v3", description="Strava API base URL")
tests/test_auth.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the Strava authentication module."""
2
+
3
+ import asyncio
4
+ from unittest.mock import AsyncMock, MagicMock, patch
5
+
6
+ import pytest
7
+ from fastapi import FastAPI
8
+ from fastapi.testclient import TestClient
9
+ from httpx import Response
10
+
11
+ from strava_mcp.auth import StravaAuthenticator, TokenResponse, get_strava_refresh_token
12
+
13
+
14
+ @pytest.fixture
15
+ def client_credentials():
16
+ """Fixture for client credentials."""
17
+ return {
18
+ "client_id": "test_client_id",
19
+ "client_secret": "test_client_secret",
20
+ }
21
+
22
+
23
+ @pytest.fixture
24
+ def mock_token_response():
25
+ """Fixture for token response."""
26
+ return {
27
+ "access_token": "test_access_token",
28
+ "refresh_token": "test_refresh_token",
29
+ "expires_at": 1609459200,
30
+ "expires_in": 21600,
31
+ "token_type": "Bearer",
32
+ }
33
+
34
+
35
+ @pytest.fixture
36
+ def fastapi_app():
37
+ """Fixture for FastAPI app."""
38
+ return FastAPI()
39
+
40
+
41
+ @pytest.fixture
42
+ def authenticator(client_credentials, fastapi_app):
43
+ """Fixture for StravaAuthenticator."""
44
+ return StravaAuthenticator(
45
+ client_id=client_credentials["client_id"],
46
+ client_secret=client_credentials["client_secret"],
47
+ app=fastapi_app,
48
+ )
49
+
50
+
51
+ def test_get_authorization_url(authenticator):
52
+ """Test getting the authorization URL."""
53
+ url = authenticator.get_authorization_url()
54
+
55
+ # Check that the URL contains the expected parameters
56
+ assert "https://www.strava.com/oauth/authorize" in url
57
+ assert f"client_id={authenticator.client_id}" in url
58
+ # URL is encoded, so we need to check the non-encoded parts
59
+ assert "redirect_uri=http%3A%2F%2F127.0.0.1%3A3008%2Fexchange_token" in url
60
+ assert "response_type=code" in url
61
+ assert "scope=" in url
62
+
63
+
64
+ def test_setup_routes(authenticator, fastapi_app):
65
+ """Test setting up routes."""
66
+ authenticator.setup_routes(fastapi_app)
67
+
68
+ # Check that the routes were added
69
+ routes = [route.path for route in fastapi_app.routes]
70
+ assert authenticator.redirect_path in routes
71
+ assert "/auth" in routes
72
+
73
+
74
+ def test_setup_routes_no_app(authenticator):
75
+ """Test setting up routes with no app."""
76
+ authenticator.app = None
77
+ with pytest.raises(ValueError, match="No FastAPI app provided"):
78
+ authenticator.setup_routes()
79
+
80
+
81
+ @pytest.mark.asyncio
82
+ async def test_exchange_token_success(authenticator, mock_token_response):
83
+ """Test exchanging token successfully."""
84
+ # Setup mock
85
+ with patch("httpx.AsyncClient") as mock_client:
86
+ mock_response = MagicMock(spec=Response)
87
+ mock_response.status_code = 200
88
+ mock_response.json.return_value = mock_token_response
89
+ mock_client.return_value.__aenter__.return_value.post.return_value = mock_response
90
+
91
+ # Set up a future to receive the token
92
+ authenticator.token_future = asyncio.Future()
93
+
94
+ # Call the handler
95
+ response = await authenticator.exchange_token(code="test_code")
96
+
97
+ # Check response
98
+ assert response.status_code == 200
99
+ assert "Authorization successful" in response.body.decode()
100
+
101
+ # Check token future
102
+ assert authenticator.token_future.done()
103
+ assert await authenticator.token_future == "test_refresh_token"
104
+
105
+ # Check token was saved
106
+ assert authenticator.refresh_token == "test_refresh_token"
107
+
108
+ # Verify correct API call
109
+ mock_client.return_value.__aenter__.return_value.post.assert_called_once()
110
+ args, kwargs = mock_client.return_value.__aenter__.return_value.post.call_args
111
+ assert args[0] == "https://www.strava.com/oauth/token"
112
+ assert kwargs["data"]["client_id"] == authenticator.client_id
113
+ assert kwargs["data"]["client_secret"] == authenticator.client_secret
114
+ assert kwargs["data"]["code"] == "test_code"
115
+ assert kwargs["data"]["grant_type"] == "authorization_code"
116
+
117
+
118
+ @pytest.mark.asyncio
119
+ async def test_exchange_token_failure(authenticator):
120
+ """Test exchanging token with failure."""
121
+ # Setup mock
122
+ with patch("httpx.AsyncClient") as mock_client:
123
+ mock_response = MagicMock(spec=Response)
124
+ mock_response.status_code = 400
125
+ mock_response.text = "Invalid code"
126
+ mock_client.return_value.__aenter__.return_value.post.return_value = mock_response
127
+
128
+ # Set up a future to receive the token
129
+ authenticator.token_future = asyncio.Future()
130
+
131
+ # Call the handler
132
+ response = await authenticator.exchange_token(code="invalid_code")
133
+
134
+ # Check response
135
+ assert response.status_code == 200
136
+ assert "Authorization failed" in response.body.decode()
137
+
138
+ # Check token future
139
+ assert authenticator.token_future.done()
140
+ with pytest.raises(Exception):
141
+ await authenticator.token_future
142
+
143
+
144
+ @pytest.mark.asyncio
145
+ async def test_start_auth_flow(authenticator):
146
+ """Test starting auth flow."""
147
+ with patch.object(authenticator, "get_authorization_url", return_value="https://example.com/auth"):
148
+ response = await authenticator.start_auth_flow()
149
+ assert response.status_code == 307
150
+ assert response.headers["location"] == "https://example.com/auth"
151
+
152
+
153
+ @pytest.mark.asyncio
154
+ async def test_get_refresh_token(authenticator):
155
+ """Test getting refresh token."""
156
+ # Mock the webbrowser.open call
157
+ with patch("webbrowser.open", return_value=True) as mock_open:
158
+ with patch.object(authenticator, "get_authorization_url", return_value="https://example.com/auth"):
159
+ # Set the future result after a delay
160
+ authenticator.token_future = None # Reset it so a new one is created
161
+
162
+ # Start the token request in background
163
+ task = asyncio.create_task(authenticator.get_refresh_token())
164
+
165
+ # Wait a bit and set the result
166
+ await asyncio.sleep(0.1)
167
+ authenticator.token_future.set_result("test_refresh_token")
168
+
169
+ # Get the result
170
+ token = await task
171
+
172
+ # Verify
173
+ assert token == "test_refresh_token"
174
+ mock_open.assert_called_once_with("https://example.com/auth")
175
+
176
+
177
+ @pytest.mark.asyncio
178
+ async def test_get_refresh_token_no_browser(authenticator):
179
+ """Test getting refresh token without opening browser."""
180
+ with patch("webbrowser.open") as mock_open:
181
+ with patch.object(authenticator, "get_authorization_url", return_value="https://example.com/auth"):
182
+ # Set the future result after a delay
183
+ authenticator.token_future = None # Reset it so a new one is created
184
+
185
+ # Start the token request in background
186
+ task = asyncio.create_task(authenticator.get_refresh_token(open_browser=False))
187
+
188
+ # Wait a bit and set the result
189
+ await asyncio.sleep(0.1)
190
+ authenticator.token_future.set_result("test_refresh_token")
191
+
192
+ # Get the result
193
+ token = await task
194
+
195
+ # Verify
196
+ assert token == "test_refresh_token"
197
+ mock_open.assert_not_called()
198
+
199
+
200
+ @pytest.mark.asyncio
201
+ async def test_get_refresh_token_browser_fails(authenticator):
202
+ """Test getting refresh token with browser opening failing."""
203
+ with patch("webbrowser.open", return_value=False) as mock_open:
204
+ with patch.object(authenticator, "get_authorization_url", return_value="https://example.com/auth"):
205
+ # Set the future result after a delay
206
+ authenticator.token_future = None # Reset it so a new one is created
207
+
208
+ # Start the token request in background
209
+ task = asyncio.create_task(authenticator.get_refresh_token())
210
+
211
+ # Wait a bit and set the result
212
+ await asyncio.sleep(0.1)
213
+ authenticator.token_future.set_result("test_refresh_token")
214
+
215
+ # Get the result
216
+ token = await task
217
+
218
+ # Verify
219
+ assert token == "test_refresh_token"
220
+ mock_open.assert_called_once_with("https://example.com/auth")
221
+
222
+
223
+ @pytest.mark.asyncio
224
+ async def test_get_strava_refresh_token(client_credentials):
225
+ """Test get_strava_refresh_token function."""
226
+ with patch("strava_mcp.auth.StravaAuthenticator") as MockAuthenticator:
227
+ # Setup mock
228
+ mock_authenticator = MagicMock()
229
+ mock_authenticator.get_refresh_token = AsyncMock(return_value="test_refresh_token")
230
+ mock_authenticator.setup_routes = MagicMock()
231
+ MockAuthenticator.return_value = mock_authenticator
232
+
233
+ # Test without app
234
+ token = await get_strava_refresh_token(
235
+ client_credentials["client_id"],
236
+ client_credentials["client_secret"]
237
+ )
238
+
239
+ # Verify
240
+ assert token == "test_refresh_token"
241
+ MockAuthenticator.assert_called_once_with(
242
+ client_credentials["client_id"],
243
+ client_credentials["client_secret"],
244
+ None
245
+ )
246
+ mock_authenticator.setup_routes.assert_not_called()
247
+
248
+ # Reset mocks
249
+ MockAuthenticator.reset_mock()
250
+ mock_authenticator.get_refresh_token.reset_mock()
251
+ mock_authenticator.setup_routes.reset_mock()
252
+
253
+ # Test with app
254
+ app = FastAPI()
255
+ token = await get_strava_refresh_token(
256
+ client_credentials["client_id"],
257
+ client_credentials["client_secret"],
258
+ app
259
+ )
260
+
261
+ # Verify
262
+ assert token == "test_refresh_token"
263
+ MockAuthenticator.assert_called_once_with(
264
+ client_credentials["client_id"],
265
+ client_credentials["client_secret"],
266
+ app
267
+ )
268
+ mock_authenticator.setup_routes.assert_called_once_with(app)
tests/test_config.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for configuration module."""
2
+
3
+ import os
4
+ from unittest import mock
5
+
6
+ import pytest
7
+
8
+ from strava_mcp.config import StravaSettings
9
+
10
+
11
+ def test_strava_settings_defaults():
12
+ """Test default settings for StravaSettings."""
13
+ # Use required parameters only
14
+ with mock.patch.dict(os.environ, {}, clear=True):
15
+ settings = StravaSettings(
16
+ client_id="test_client_id",
17
+ client_secret="test_client_secret",
18
+ )
19
+
20
+ assert settings.client_id == "test_client_id"
21
+ assert settings.client_secret == "test_client_secret"
22
+ assert settings.refresh_token is None
23
+ assert settings.base_url == "https://www.strava.com/api/v3"
24
+
25
+
26
+ def test_strava_settings_from_env():
27
+ """Test loading settings from environment variables."""
28
+ with mock.patch.dict(
29
+ os.environ,
30
+ {
31
+ "STRAVA_CLIENT_ID": "env_client_id",
32
+ "STRAVA_CLIENT_SECRET": "env_client_secret",
33
+ "STRAVA_REFRESH_TOKEN": "env_refresh_token",
34
+ "STRAVA_BASE_URL": "https://custom.strava.api/v3",
35
+ },
36
+ ):
37
+ settings = StravaSettings()
38
+
39
+ assert settings.client_id == "env_client_id"
40
+ assert settings.client_secret == "env_client_secret"
41
+ assert settings.refresh_token == "env_refresh_token"
42
+ assert settings.base_url == "https://custom.strava.api/v3"
43
+
44
+
45
+ def test_strava_settings_override():
46
+ """Test overriding environment settings with direct values."""
47
+ with mock.patch.dict(
48
+ os.environ,
49
+ {
50
+ "STRAVA_CLIENT_ID": "env_client_id",
51
+ "STRAVA_CLIENT_SECRET": "env_client_secret",
52
+ "STRAVA_REFRESH_TOKEN": "env_refresh_token",
53
+ },
54
+ ):
55
+ settings = StravaSettings(
56
+ client_id="direct_client_id",
57
+ refresh_token="direct_refresh_token",
58
+ )
59
+
60
+ # Direct values should override environment variables
61
+ assert settings.client_id == "direct_client_id"
62
+ assert settings.client_secret == "env_client_secret"
63
+ assert settings.refresh_token == "direct_refresh_token"
64
+
65
+
66
+ def test_strava_settings_model_config():
67
+ """Test model configuration for StravaSettings."""
68
+ assert StravaSettings.model_config["env_prefix"] == "STRAVA_"
69
+ assert StravaSettings.model_config["env_file"] == ".env"
70
+ assert StravaSettings.model_config["env_file_encoding"] == "utf-8"
tests/test_models.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the Strava models."""
2
+
3
+ import pytest
4
+ from datetime import datetime
5
+ from pydantic import ValidationError
6
+
7
+ from strava_mcp.models import Activity, DetailedActivity, Segment, SegmentEffort, ErrorResponse
8
+
9
+
10
+ @pytest.fixture
11
+ def activity_data():
12
+ """Fixture with valid activity data."""
13
+ return {
14
+ "id": 1234567890,
15
+ "name": "Morning Run",
16
+ "distance": 5000.0,
17
+ "moving_time": 1200,
18
+ "elapsed_time": 1300,
19
+ "total_elevation_gain": 50.0,
20
+ "type": "Run",
21
+ "sport_type": "Run",
22
+ "start_date": "2023-01-01T10:00:00Z",
23
+ "start_date_local": "2023-01-01T10:00:00Z",
24
+ "timezone": "Europe/London",
25
+ "achievement_count": 2,
26
+ "kudos_count": 5,
27
+ "comment_count": 0,
28
+ "athlete_count": 1,
29
+ "photo_count": 0,
30
+ "trainer": False,
31
+ "commute": False,
32
+ "manual": False,
33
+ "private": False,
34
+ "flagged": False,
35
+ "average_speed": 4.167,
36
+ "max_speed": 5.3,
37
+ "has_heartrate": True,
38
+ "average_heartrate": 140.0,
39
+ "max_heartrate": 160.0,
40
+ }
41
+
42
+
43
+ @pytest.fixture
44
+ def detailed_activity_data(activity_data):
45
+ """Fixture with valid detailed activity data."""
46
+ return {
47
+ **activity_data,
48
+ "description": "Test description",
49
+ "athlete": {"id": 123},
50
+ "calories": 500.0,
51
+ }
52
+
53
+
54
+ @pytest.fixture
55
+ def segment_data():
56
+ """Fixture with valid segment data."""
57
+ return {
58
+ "id": 12345,
59
+ "name": "Test Segment",
60
+ "activity_type": "Run",
61
+ "distance": 1000.0,
62
+ "average_grade": 5.0,
63
+ "maximum_grade": 10.0,
64
+ "elevation_high": 200.0,
65
+ "elevation_low": 150.0,
66
+ "total_elevation_gain": 50.0,
67
+ "start_latlng": [51.5, -0.1],
68
+ "end_latlng": [51.5, -0.2],
69
+ "climb_category": 0,
70
+ "private": False,
71
+ "starred": False,
72
+ }
73
+
74
+
75
+ @pytest.fixture
76
+ def segment_effort_data(segment_data):
77
+ """Fixture with valid segment effort data."""
78
+ return {
79
+ "id": 67890,
80
+ "activity_id": 1234567890,
81
+ "segment_id": 12345,
82
+ "name": "Test Segment",
83
+ "elapsed_time": 180,
84
+ "moving_time": 180,
85
+ "start_date": "2023-01-01T10:05:00Z",
86
+ "start_date_local": "2023-01-01T10:05:00Z",
87
+ "distance": 1000.0,
88
+ "athlete": {"id": 123},
89
+ "segment": segment_data,
90
+ }
91
+
92
+
93
+ def test_activity_model(activity_data):
94
+ """Test the Activity model."""
95
+ activity = Activity(**activity_data)
96
+
97
+ assert activity.id == activity_data["id"]
98
+ assert activity.name == activity_data["name"]
99
+ assert activity.distance == activity_data["distance"]
100
+ assert activity.start_date == datetime.fromisoformat(activity_data["start_date"].replace("Z", "+00:00"))
101
+ assert activity.start_date_local == datetime.fromisoformat(activity_data["start_date_local"].replace("Z", "+00:00"))
102
+ assert activity.average_heartrate == activity_data["average_heartrate"]
103
+ assert activity.max_heartrate == activity_data["max_heartrate"]
104
+
105
+
106
+ def test_activity_model_optional_fields(activity_data):
107
+ """Test the Activity model with optional fields."""
108
+ # Remove some optional fields
109
+ data = activity_data.copy()
110
+ data.pop("average_heartrate")
111
+ data.pop("max_heartrate")
112
+
113
+ activity = Activity(**data)
114
+
115
+ assert activity.average_heartrate is None
116
+ assert activity.max_heartrate is None
117
+
118
+
119
+ def test_activity_model_missing_required_fields(activity_data):
120
+ """Test the Activity model with missing required fields."""
121
+ data = activity_data.copy()
122
+ data.pop("id") # Remove a required field
123
+
124
+ with pytest.raises(ValidationError):
125
+ Activity(**data)
126
+
127
+
128
+ def test_detailed_activity_model(detailed_activity_data):
129
+ """Test the DetailedActivity model."""
130
+ activity = DetailedActivity(**detailed_activity_data)
131
+
132
+ assert activity.id == detailed_activity_data["id"]
133
+ assert activity.name == detailed_activity_data["name"]
134
+ assert activity.description == detailed_activity_data["description"]
135
+ assert activity.athlete == detailed_activity_data["athlete"]
136
+ assert activity.calories == detailed_activity_data["calories"]
137
+
138
+
139
+ def test_detailed_activity_optional_fields(detailed_activity_data):
140
+ """Test the DetailedActivity model with optional fields."""
141
+ data = detailed_activity_data.copy()
142
+ data.pop("description")
143
+ data.pop("calories")
144
+
145
+ activity = DetailedActivity(**data)
146
+
147
+ assert activity.description is None
148
+ assert activity.calories is None
149
+
150
+
151
+ def test_segment_model(segment_data):
152
+ """Test the Segment model."""
153
+ segment = Segment(**segment_data)
154
+
155
+ assert segment.id == segment_data["id"]
156
+ assert segment.name == segment_data["name"]
157
+ assert segment.activity_type == segment_data["activity_type"]
158
+ assert segment.distance == segment_data["distance"]
159
+ assert segment.start_latlng == segment_data["start_latlng"]
160
+ assert segment.end_latlng == segment_data["end_latlng"]
161
+
162
+
163
+ def test_segment_optional_fields(segment_data):
164
+ """Test the Segment model with optional fields."""
165
+ # Add some optional fields
166
+ data = segment_data.copy()
167
+ data["city"] = "London"
168
+ data["state"] = "Greater London"
169
+ data["country"] = "United Kingdom"
170
+
171
+ segment = Segment(**data)
172
+
173
+ assert segment.city == "London"
174
+ assert segment.state == "Greater London"
175
+ assert segment.country == "United Kingdom"
176
+
177
+
178
+ def test_segment_missing_fields(segment_data):
179
+ """Test the Segment model with missing required fields."""
180
+ data = segment_data.copy()
181
+ data.pop("id") # Remove a required field
182
+
183
+ with pytest.raises(ValidationError):
184
+ Segment(**data)
185
+
186
+
187
+ def test_segment_effort_model(segment_effort_data):
188
+ """Test the SegmentEffort model."""
189
+ effort = SegmentEffort(**segment_effort_data)
190
+
191
+ assert effort.id == segment_effort_data["id"]
192
+ assert effort.activity_id == segment_effort_data["activity_id"]
193
+ assert effort.segment_id == segment_effort_data["segment_id"]
194
+ assert effort.name == segment_effort_data["name"]
195
+ assert effort.elapsed_time == segment_effort_data["elapsed_time"]
196
+ assert effort.moving_time == segment_effort_data["moving_time"]
197
+ assert effort.start_date == datetime.fromisoformat(segment_effort_data["start_date"].replace("Z", "+00:00"))
198
+ assert effort.start_date_local == datetime.fromisoformat(segment_effort_data["start_date_local"].replace("Z", "+00:00"))
199
+ assert effort.distance == segment_effort_data["distance"]
200
+ assert effort.athlete == segment_effort_data["athlete"]
201
+
202
+ # Test nested segment object
203
+ assert effort.segment.id == segment_effort_data["segment"]["id"]
204
+ assert effort.segment.name == segment_effort_data["segment"]["name"]
205
+
206
+
207
+ def test_segment_effort_optional_fields(segment_effort_data):
208
+ """Test the SegmentEffort model with optional fields."""
209
+ # Add some optional fields
210
+ data = segment_effort_data.copy()
211
+ data["average_watts"] = 200.0
212
+ data["device_watts"] = True
213
+ data["average_heartrate"] = 150.0
214
+ data["max_heartrate"] = 170.0
215
+ data["pr_rank"] = 1
216
+ data["achievements"] = [{"type": "overall", "rank": 1}]
217
+
218
+ effort = SegmentEffort(**data)
219
+
220
+ assert effort.average_watts == 200.0
221
+ assert effort.device_watts is True
222
+ assert effort.average_heartrate == 150.0
223
+ assert effort.max_heartrate == 170.0
224
+ assert effort.pr_rank == 1
225
+ assert effort.achievements == [{"type": "overall", "rank": 1}]
226
+
227
+
228
+ def test_segment_effort_missing_fields(segment_effort_data):
229
+ """Test the SegmentEffort model with missing required fields."""
230
+ data = segment_effort_data.copy()
231
+ data.pop("segment") # Remove a required field
232
+
233
+ with pytest.raises(ValidationError):
234
+ SegmentEffort(**data)
235
+
236
+
237
+ def test_error_response():
238
+ """Test the ErrorResponse model."""
239
+ data = {
240
+ "message": "Resource not found",
241
+ "code": 404
242
+ }
243
+
244
+ error = ErrorResponse(**data)
245
+
246
+ assert error.message == "Resource not found"
247
+ assert error.code == 404
tests/test_oauth_server.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the Strava OAuth server module."""
2
+
3
+ import asyncio
4
+ from unittest.mock import AsyncMock, MagicMock, patch
5
+
6
+ import pytest
7
+ import uvicorn
8
+ from fastapi import FastAPI
9
+
10
+ from strava_mcp.auth import StravaAuthenticator
11
+ from strava_mcp.oauth_server import StravaOAuthServer, get_refresh_token_from_oauth
12
+
13
+
14
+ @pytest.fixture
15
+ def client_credentials():
16
+ """Fixture for client credentials."""
17
+ return {
18
+ "client_id": "test_client_id",
19
+ "client_secret": "test_client_secret",
20
+ }
21
+
22
+
23
+ @pytest.fixture
24
+ def oauth_server(client_credentials):
25
+ """Fixture for StravaOAuthServer."""
26
+ return StravaOAuthServer(
27
+ client_id=client_credentials["client_id"],
28
+ client_secret=client_credentials["client_secret"],
29
+ )
30
+
31
+
32
+ @pytest.mark.asyncio
33
+ async def test_initialize_server(oauth_server):
34
+ """Test initializing the server."""
35
+ # Mock the OAuth server's dependencies directly
36
+ with patch("strava_mcp.oauth_server.StravaAuthenticator") as MockAuthenticator:
37
+ with patch("asyncio.create_task") as mock_create_task:
38
+ # Setup mocks
39
+ mock_authenticator = MagicMock()
40
+ MockAuthenticator.return_value = mock_authenticator
41
+ mock_task = MagicMock()
42
+ mock_create_task.return_value = mock_task
43
+
44
+ # Test method
45
+ await oauth_server._initialize_server()
46
+
47
+ # Verify FastAPI app was created
48
+ assert oauth_server.app is not None
49
+ assert oauth_server.app.title == "Strava OAuth"
50
+
51
+ # Verify authenticator was created and configured
52
+ MockAuthenticator.assert_called_once_with(
53
+ client_id=oauth_server.client_id,
54
+ client_secret=oauth_server.client_secret,
55
+ app=oauth_server.app,
56
+ host=oauth_server.host,
57
+ port=oauth_server.port,
58
+ )
59
+ assert oauth_server.authenticator == mock_authenticator
60
+
61
+ # Verify token future was stored in authenticator
62
+ assert mock_authenticator.token_future is oauth_server.token_future
63
+
64
+ # Verify routes were set up
65
+ mock_authenticator.setup_routes.assert_called_once_with(oauth_server.app)
66
+
67
+ # Verify server task was created
68
+ mock_create_task.assert_called_once()
69
+ assert oauth_server.server_task == mock_task
70
+
71
+
72
+ @pytest.mark.asyncio
73
+ async def test_run_server(oauth_server):
74
+ """Test running the server."""
75
+ with patch("uvicorn.Server") as MockServer:
76
+ with patch("uvicorn.Config") as MockConfig:
77
+ # Setup mocks
78
+ mock_server = AsyncMock()
79
+ MockServer.return_value = mock_server
80
+ mock_config = MagicMock()
81
+ MockConfig.return_value = mock_config
82
+
83
+ # Create app
84
+ oauth_server.app = FastAPI()
85
+
86
+ # Test method
87
+ await oauth_server._run_server()
88
+
89
+ # Verify config was created correctly
90
+ MockConfig.assert_called_once_with(
91
+ app=oauth_server.app,
92
+ host=oauth_server.host,
93
+ port=oauth_server.port,
94
+ log_level="info",
95
+ )
96
+
97
+ # Verify server was created and run
98
+ MockServer.assert_called_once_with(mock_config)
99
+ mock_server.serve.assert_called_once()
100
+ assert oauth_server.server == mock_server
101
+
102
+
103
+ @pytest.mark.asyncio
104
+ async def test_run_server_exception(oauth_server):
105
+ """Test running the server with an exception."""
106
+ with patch("uvicorn.Server") as MockServer:
107
+ with patch("uvicorn.Config") as MockConfig:
108
+ # Setup mocks
109
+ mock_server = AsyncMock()
110
+ mock_server.serve = AsyncMock(side_effect=Exception("Test error"))
111
+ MockServer.return_value = mock_server
112
+ mock_config = MagicMock()
113
+ MockConfig.return_value = mock_config
114
+
115
+ # Create app and token future
116
+ oauth_server.app = FastAPI()
117
+ oauth_server.token_future = asyncio.Future()
118
+
119
+ # Test method
120
+ await oauth_server._run_server()
121
+
122
+ # Verify token future has exception
123
+ assert oauth_server.token_future.done()
124
+ with pytest.raises(Exception, match="Test error"):
125
+ await oauth_server.token_future
126
+
127
+
128
+ @pytest.mark.asyncio
129
+ async def test_stop_server(oauth_server):
130
+ """Test stopping the server."""
131
+ # Setup server and task
132
+ oauth_server.server = MagicMock()
133
+ oauth_server.server_task = MagicMock()
134
+ oauth_server.server_task.done = MagicMock(return_value=False)
135
+
136
+ # Make asyncio.wait_for return immediately
137
+ with patch("asyncio.wait_for", new=AsyncMock()) as mock_wait_for:
138
+ # Test method
139
+ await oauth_server._stop_server()
140
+
141
+ # Verify server was stopped
142
+ assert oauth_server.server.should_exit is True
143
+ mock_wait_for.assert_called_once_with(oauth_server.server_task, timeout=5.0)
144
+
145
+
146
+ @pytest.mark.asyncio
147
+ async def test_stop_server_timeout(oauth_server):
148
+ """Test stopping the server with timeout."""
149
+ # Setup server and task
150
+ oauth_server.server = MagicMock()
151
+ oauth_server.server_task = MagicMock()
152
+ oauth_server.server_task.done = MagicMock(return_value=False)
153
+
154
+ # Make asyncio.wait_for raise TimeoutError
155
+ with patch("asyncio.wait_for", new=AsyncMock(side_effect=TimeoutError())) as mock_wait_for:
156
+ # Test method
157
+ await oauth_server._stop_server()
158
+
159
+ # Verify server was stopped
160
+ assert oauth_server.server.should_exit is True
161
+ mock_wait_for.assert_called_once_with(oauth_server.server_task, timeout=5.0)
162
+
163
+
164
+ @pytest.mark.asyncio
165
+ async def test_get_token(oauth_server):
166
+ """Test getting a token."""
167
+ # Setup mocks
168
+ oauth_server._initialize_server = AsyncMock()
169
+ oauth_server._stop_server = AsyncMock()
170
+ oauth_server.authenticator = MagicMock()
171
+ oauth_server.authenticator.get_authorization_url = MagicMock(return_value="https://example.com/auth")
172
+
173
+ with patch("webbrowser.open") as mock_open:
174
+ # Prepare token future
175
+ oauth_server.token_future = asyncio.Future()
176
+ oauth_server.token_future.set_result("test_refresh_token")
177
+
178
+ # Test method
179
+ token = await oauth_server.get_token()
180
+
181
+ # Verify
182
+ assert token == "test_refresh_token"
183
+ oauth_server._initialize_server.assert_called_once()
184
+ oauth_server.authenticator.get_authorization_url.assert_called_once()
185
+ mock_open.assert_called_once_with("https://example.com/auth")
186
+ oauth_server._stop_server.assert_called_once()
187
+
188
+
189
+ @pytest.mark.asyncio
190
+ async def test_get_token_no_authenticator(oauth_server):
191
+ """Test getting a token with no authenticator."""
192
+ # Setup mocks
193
+ oauth_server._initialize_server = AsyncMock()
194
+ oauth_server._stop_server = AsyncMock()
195
+ oauth_server.authenticator = None
196
+
197
+ # Test method
198
+ with pytest.raises(Exception, match="Authenticator not initialized"):
199
+ await oauth_server.get_token()
200
+
201
+ # Verify
202
+ oauth_server._initialize_server.assert_called_once()
203
+ # The stop server is not called because we exit with exception before getting there
204
+ # oauth_server._stop_server.assert_called_once()
205
+
206
+
207
+ @pytest.mark.asyncio
208
+ async def test_get_token_cancelled(oauth_server):
209
+ """Test getting a token that is cancelled."""
210
+ # Setup mocks
211
+ oauth_server._initialize_server = AsyncMock()
212
+ oauth_server._stop_server = AsyncMock()
213
+ oauth_server.authenticator = MagicMock()
214
+ oauth_server.authenticator.get_authorization_url = MagicMock(return_value="https://example.com/auth")
215
+
216
+ with patch("webbrowser.open") as mock_open:
217
+ # Prepare token future with cancellation
218
+ oauth_server.token_future = asyncio.Future()
219
+ oauth_server.token_future.cancel()
220
+
221
+ # Test method
222
+ with pytest.raises(Exception, match="OAuth flow was cancelled"):
223
+ await oauth_server.get_token()
224
+
225
+ # Verify
226
+ oauth_server._initialize_server.assert_called_once()
227
+ oauth_server.authenticator.get_authorization_url.assert_called_once()
228
+ mock_open.assert_called_once_with("https://example.com/auth")
229
+ oauth_server._stop_server.assert_called_once()
230
+
231
+
232
+ @pytest.mark.asyncio
233
+ async def test_get_token_exception(oauth_server):
234
+ """Test getting a token with exception."""
235
+ # Setup mocks
236
+ oauth_server._initialize_server = AsyncMock()
237
+ oauth_server._stop_server = AsyncMock()
238
+ oauth_server.authenticator = MagicMock()
239
+ oauth_server.authenticator.get_authorization_url = MagicMock(return_value="https://example.com/auth")
240
+
241
+ with patch("webbrowser.open") as mock_open:
242
+ # Prepare token future with exception
243
+ oauth_server.token_future = asyncio.Future()
244
+ oauth_server.token_future.set_exception(Exception("Test error"))
245
+
246
+ # Test method
247
+ with pytest.raises(Exception, match="OAuth flow failed: Test error"):
248
+ await oauth_server.get_token()
249
+
250
+ # Verify
251
+ oauth_server._initialize_server.assert_called_once()
252
+ oauth_server.authenticator.get_authorization_url.assert_called_once()
253
+ mock_open.assert_called_once_with("https://example.com/auth")
254
+ oauth_server._stop_server.assert_called_once()
255
+
256
+
257
+ @pytest.mark.asyncio
258
+ async def test_get_refresh_token_from_oauth(client_credentials):
259
+ """Test get_refresh_token_from_oauth function."""
260
+ with patch("strava_mcp.oauth_server.StravaOAuthServer") as MockOAuthServer:
261
+ # Setup mock
262
+ mock_server = MagicMock()
263
+ mock_server.get_token = AsyncMock(return_value="test_refresh_token")
264
+ MockOAuthServer.return_value = mock_server
265
+
266
+ # Test function
267
+ token = await get_refresh_token_from_oauth(
268
+ client_credentials["client_id"],
269
+ client_credentials["client_secret"]
270
+ )
271
+
272
+ # Verify
273
+ assert token == "test_refresh_token"
274
+ MockOAuthServer.assert_called_once_with(
275
+ client_credentials["client_id"],
276
+ client_credentials["client_secret"]
277
+ )
278
+ mock_server.get_token.assert_called_once()
tests/test_server.py CHANGED
@@ -1,10 +1,17 @@
1
- from unittest.mock import AsyncMock, MagicMock
2
 
3
  import pytest
4
 
5
  from strava_mcp.models import Activity, DetailedActivity, SegmentEffort
6
 
7
 
 
 
 
 
 
 
 
8
  class MockContext:
9
  """Mock MCP context for testing."""
10
 
@@ -167,4 +174,4 @@ async def test_get_activity_segments(mock_ctx, mock_service):
167
  # Verify result
168
  assert len(result) == 1
169
  assert result[0]["id"] == mock_segment.id
170
- assert result[0]["name"] == mock_segment.name
 
1
+ from unittest.mock import AsyncMock, MagicMock, patch
2
 
3
  import pytest
4
 
5
  from strava_mcp.models import Activity, DetailedActivity, SegmentEffort
6
 
7
 
8
+ # Patch the StravaOAuthServer._run_server method to prevent coroutine warnings
9
+ # This must be at the module level before any imports that might create the coroutine
10
+ with patch("strava_mcp.oauth_server.StravaOAuthServer._run_server", new_callable=AsyncMock):
11
+ # Now imports will use the patched version
12
+ pass
13
+
14
+
15
  class MockContext:
16
  """Mock MCP context for testing."""
17
 
 
174
  # Verify result
175
  assert len(result) == 1
176
  assert result[0]["id"] == mock_segment.id
177
+ assert result[0]["name"] == mock_segment.name
uv.lock CHANGED
@@ -888,7 +888,7 @@ wheels = [
888
  [[package]]
889
  name = "strava"
890
  version = "0.1.0"
891
- source = { virtual = "." }
892
  dependencies = [
893
  { name = "fastapi" },
894
  { name = "httpx" },
 
888
  [[package]]
889
  name = "strava"
890
  version = "0.1.0"
891
+ source = { editable = "." }
892
  dependencies = [
893
  { name = "fastapi" },
894
  { name = "httpx" },