NeerajCodz committed on
Commit
e8d7c11
·
1 Parent(s): 864b733

test: add comprehensive API and core module tests

Browse files

- Add 30+ new tests for memory, tasks, episode modules
- Episode tests cover lifecycle, steps, manager operations
- Memory API tests cover store, query, delete operations
- Tasks API tests cover list, filter, create operations
- All 101 tests passing with 44% coverage

backend/.coverage ADDED
Binary file (53.2 kB). View file
 
backend/app/agents/__pycache__/coordinator.cpython-314.pyc CHANGED
Binary files a/backend/app/agents/__pycache__/coordinator.cpython-314.pyc and b/backend/app/agents/__pycache__/coordinator.cpython-314.pyc differ
 
backend/app/agents/__pycache__/memory_agent.cpython-314.pyc CHANGED
Binary files a/backend/app/agents/__pycache__/memory_agent.cpython-314.pyc and b/backend/app/agents/__pycache__/memory_agent.cpython-314.pyc differ
 
backend/app/api/routes/__pycache__/agents.cpython-314.pyc CHANGED
Binary files a/backend/app/api/routes/__pycache__/agents.cpython-314.pyc and b/backend/app/api/routes/__pycache__/agents.cpython-314.pyc differ
 
backend/app/memory/__pycache__/long_term.cpython-314.pyc CHANGED
Binary files a/backend/app/memory/__pycache__/long_term.cpython-314.pyc and b/backend/app/memory/__pycache__/long_term.cpython-314.pyc differ
 
backend/app/memory/__pycache__/short_term.cpython-314.pyc CHANGED
Binary files a/backend/app/memory/__pycache__/short_term.cpython-314.pyc and b/backend/app/memory/__pycache__/short_term.cpython-314.pyc differ
 
backend/app/memory/__pycache__/working.cpython-314.pyc CHANGED
Binary files a/backend/app/memory/__pycache__/working.cpython-314.pyc and b/backend/app/memory/__pycache__/working.cpython-314.pyc differ
 
backend/app/models/__pycache__/router.cpython-314.pyc CHANGED
Binary files a/backend/app/models/__pycache__/router.cpython-314.pyc and b/backend/app/models/__pycache__/router.cpython-314.pyc differ
 
backend/tests/test_api/test_memory.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for memory API routes."""
2
+
3
+ import pytest
4
+ from fastapi.testclient import TestClient
5
+
6
+
7
+ class TestMemoryAPI:
8
+ """Test memory API endpoints."""
9
+
10
+ def test_store_memory_entry(self, client: TestClient) -> None:
11
+ """Test POST /api/memory/store creates new memory entry."""
12
+ payload = {
13
+ "memory_type": "short_term",
14
+ "content": {
15
+ "observation": "User clicked login button",
16
+ "action": "click",
17
+ },
18
+ "metadata": {"url": "https://example.com"},
19
+ "episode_id": "ep_001",
20
+ "agent_id": "agent_test",
21
+ }
22
+
23
+ response = client.post("/api/memory/store", json=payload)
24
+
25
+ assert response.status_code == 201
26
+ data = response.json()
27
+
28
+ assert "id" in data
29
+ assert data["memory_type"] == "short_term"
30
+ assert data["content"] == payload["content"]
31
+ assert data["episode_id"] == "ep_001"
32
+ assert "timestamp" in data
33
+
34
+ def test_store_memory_all_types(self, client: TestClient) -> None:
35
+ """Test storing memory with all valid types."""
36
+ valid_types = ["short_term", "working", "long_term", "shared"]
37
+
38
+ for memory_type in valid_types:
39
+ payload = {
40
+ "memory_type": memory_type,
41
+ "content": {"test": f"data for {memory_type}"},
42
+ }
43
+
44
+ response = client.post("/api/memory/store", json=payload)
45
+ assert response.status_code == 201
46
+ data = response.json()
47
+ assert data["memory_type"] == memory_type
48
+
49
+ def test_store_memory_invalid_type(self, client: TestClient) -> None:
50
+ """Test storing memory with invalid type."""
51
+ payload = {"memory_type": "invalid_type", "content": {"test": "data"}}
52
+
53
+ response = client.post("/api/memory/store", json=payload)
54
+ assert response.status_code == 422
55
+
56
+ def test_get_memory_entry(self, client: TestClient) -> None:
57
+ """Test GET /api/memory/{entry_id}."""
58
+ # Store first
59
+ payload = {
60
+ "memory_type": "long_term",
61
+ "content": {"knowledge": "test data"},
62
+ }
63
+
64
+ store_response = client.post("/api/memory/store", json=payload)
65
+ assert store_response.status_code == 201
66
+ entry_id = store_response.json()["id"]
67
+
68
+ # Retrieve
69
+ response = client.get(f"/api/memory/{entry_id}")
70
+ assert response.status_code == 200
71
+ data = response.json()
72
+ assert data["id"] == entry_id
73
+ assert data["memory_type"] == "long_term"
74
+
75
+ def test_get_nonexistent_memory(self, client: TestClient) -> None:
76
+ """Test GET /api/memory/{entry_id} for non-existent."""
77
+ response = client.get("/api/memory/nonexistent-id-12345")
78
+ assert response.status_code == 404
79
+
80
+ def test_delete_memory_entry(self, client: TestClient) -> None:
81
+ """Test DELETE /api/memory/{entry_id}."""
82
+ # Store first
83
+ payload = {
84
+ "memory_type": "short_term",
85
+ "content": {"temporary": "data"},
86
+ }
87
+
88
+ store_response = client.post("/api/memory/store", json=payload)
89
+ assert store_response.status_code == 201
90
+ entry_id = store_response.json()["id"]
91
+
92
+ # Delete
93
+ response = client.delete(f"/api/memory/{entry_id}")
94
+ assert response.status_code == 204
95
+
96
+ # Verify deleted
97
+ get_response = client.get(f"/api/memory/{entry_id}")
98
+ assert get_response.status_code == 404
99
+
100
+ def test_query_memory(self, client: TestClient) -> None:
101
+ """Test POST /api/memory/query."""
102
+ # Store some entries first
103
+ for i in range(3):
104
+ payload = {
105
+ "memory_type": "short_term",
106
+ "content": {"index": i, "data": f"test_{i}"},
107
+ }
108
+ client.post("/api/memory/store", json=payload)
109
+
110
+ # Query
111
+ query_payload = {"query": "test", "limit": 10}
112
+ response = client.post("/api/memory/query", json=query_payload)
113
+
114
+ assert response.status_code == 200
115
+ data = response.json()
116
+ assert "entries" in data
117
+ assert "total_found" in data
118
+
119
+ def test_get_memory_stats(self, client: TestClient) -> None:
120
+ """Test GET /api/memory/stats/overview."""
121
+ response = client.get("/api/memory/stats/overview")
122
+
123
+ assert response.status_code == 200
124
+ data = response.json()
125
+
126
+ assert "short_term_count" in data
127
+ assert "working_count" in data
128
+ assert "long_term_count" in data
129
+ assert "shared_count" in data
130
+ assert "total_count" in data
131
+
132
+ def test_clear_memory_layer(self, client: TestClient) -> None:
133
+ """Test DELETE /api/memory/clear/{memory_type}."""
134
+ # Store entries
135
+ payload = {
136
+ "memory_type": "short_term",
137
+ "content": {"test": "data"},
138
+ }
139
+ client.post("/api/memory/store", json=payload)
140
+
141
+ # Clear
142
+ response = client.delete("/api/memory/clear/short_term")
143
+ assert response.status_code == 204
144
+
145
+ def test_consolidate_memory(self, client: TestClient) -> None:
146
+ """Test POST /api/memory/consolidate."""
147
+ # Store short-term entries
148
+ for i in range(3):
149
+ payload = {
150
+ "memory_type": "short_term",
151
+ "content": {"index": i},
152
+ }
153
+ client.post("/api/memory/store", json=payload)
154
+
155
+ # Consolidate
156
+ response = client.post("/api/memory/consolidate")
157
+
158
+ assert response.status_code == 200
159
+ data = response.json()
160
+ assert "consolidated_count" in data
backend/tests/test_api/test_tasks.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for tasks API routes."""
2
+
3
+ import pytest
4
+ from fastapi.testclient import TestClient
5
+
6
+
7
+ class TestTasksAPI:
8
+ """Test tasks API endpoints."""
9
+
10
+ def test_list_tasks(self, client: TestClient) -> None:
11
+ """Test GET /api/tasks/ returns task list."""
12
+ response = client.get("/api/tasks/")
13
+
14
+ assert response.status_code == 200
15
+ data = response.json()
16
+
17
+ assert "tasks" in data
18
+ assert "total" in data
19
+ assert "page" in data
20
+ assert "page_size" in data
21
+ assert isinstance(data["tasks"], list)
22
+
23
+ def test_list_tasks_pagination(self, client: TestClient) -> None:
24
+ """Test task list pagination."""
25
+ response = client.get("/api/tasks/?page=1&page_size=2")
26
+
27
+ assert response.status_code == 200
28
+ data = response.json()
29
+ assert data["page"] == 1
30
+ assert data["page_size"] == 2
31
+
32
+ def test_list_tasks_filter_by_difficulty(self, client: TestClient) -> None:
33
+ """Test filtering tasks by difficulty."""
34
+ response = client.get("/api/tasks/?difficulty=easy")
35
+
36
+ assert response.status_code == 200
37
+ data = response.json()
38
+ for task in data["tasks"]:
39
+ assert task["difficulty"] == "easy"
40
+
41
+ def test_list_tasks_filter_by_type(self, client: TestClient) -> None:
42
+ """Test filtering tasks by type."""
43
+ response = client.get("/api/tasks/?task_type=single_page")
44
+
45
+ assert response.status_code == 200
46
+ data = response.json()
47
+ for task in data["tasks"]:
48
+ assert task["task_type"] == "single_page"
49
+
50
+ def test_list_tasks_filter_by_tag(self, client: TestClient) -> None:
51
+ """Test filtering tasks by tag."""
52
+ response = client.get("/api/tasks/?tag=ecommerce")
53
+
54
+ assert response.status_code == 200
55
+ data = response.json()
56
+ for task in data["tasks"]:
57
+ assert "ecommerce" in task["tags"]
58
+
59
+ def test_get_task_by_id(self, client: TestClient) -> None:
60
+ """Test GET /api/tasks/{task_id}."""
61
+ response = client.get("/api/tasks/task_001")
62
+
63
+ assert response.status_code == 200
64
+ data = response.json()
65
+ assert data["id"] == "task_001"
66
+ assert "name" in data
67
+ assert "description" in data
68
+ assert "fields_to_extract" in data
69
+
70
+ def test_get_nonexistent_task(self, client: TestClient) -> None:
71
+ """Test GET /api/tasks/{task_id} for non-existent task."""
72
+ response = client.get("/api/tasks/nonexistent-task-id")
73
+ assert response.status_code == 404
74
+
75
+ def test_create_task(self, client: TestClient) -> None:
76
+ """Test POST /api/tasks/ creates a new task."""
77
+ import uuid
78
+ task_id = f"test_task_{uuid.uuid4().hex[:8]}"
79
+
80
+ payload = {
81
+ "id": task_id,
82
+ "name": "Test Task",
83
+ "description": "A test scraping task",
84
+ "task_type": "single_page",
85
+ "difficulty": "easy",
86
+ "target_url": "https://example.com/test",
87
+ "fields_to_extract": [
88
+ {
89
+ "name": "title",
90
+ "description": "Page title",
91
+ "field_type": "string",
92
+ "required": True,
93
+ }
94
+ ],
95
+ "success_criteria": {"min_accuracy": 0.8},
96
+ }
97
+
98
+ response = client.post("/api/tasks/", json=payload)
99
+
100
+ assert response.status_code == 201
101
+ data = response.json()
102
+ assert data["id"] == task_id
103
+ assert data["name"] == "Test Task"
104
+
105
+ def test_create_duplicate_task(self, client: TestClient) -> None:
106
+ """Test creating duplicate task returns conflict."""
107
+ payload = {
108
+ "id": "task_001", # Existing task ID
109
+ "name": "Duplicate Task",
110
+ "description": "Should conflict",
111
+ "task_type": "single_page",
112
+ "difficulty": "easy",
113
+ "fields_to_extract": [
114
+ {"name": "test", "description": "test"}
115
+ ],
116
+ "success_criteria": {},
117
+ }
118
+
119
+ response = client.post("/api/tasks/", json=payload)
120
+ assert response.status_code == 409
121
+
122
+ def test_get_task_types(self, client: TestClient) -> None:
123
+ """Test GET /api/tasks/types/ returns task types."""
124
+ response = client.get("/api/tasks/types/")
125
+
126
+ assert response.status_code == 200
127
+ data = response.json()
128
+ assert "task_types" in data
129
+ assert "difficulties" in data
130
+ assert "single_page" in data["task_types"]
131
+ assert "easy" in data["difficulties"]
backend/tests/test_core/test_episode.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for episode management."""
2
+
3
+ import pytest
4
+ from app.core.episode import Episode, EpisodeStep, EpisodeStatus, EpisodeManager
5
+
6
+
7
+ class TestEpisode:
8
+ """Test Episode class."""
9
+
10
+ def test_episode_creation(self) -> None:
11
+ """Test creating an episode."""
12
+ episode = Episode(
13
+ episode_id="ep_001",
14
+ task_id="task_001",
15
+ max_steps=10,
16
+ )
17
+
18
+ assert episode.episode_id == "ep_001"
19
+ assert episode.task_id == "task_001"
20
+ assert episode.max_steps == 10
21
+ assert episode.status == EpisodeStatus.PENDING
22
+ assert len(episode.steps) == 0
23
+
24
+ def test_episode_start(self) -> None:
25
+ """Test starting an episode."""
26
+ episode = Episode(episode_id="ep_002", task_id="task_002")
27
+ episode.start()
28
+
29
+ assert episode.status == EpisodeStatus.RUNNING
30
+ assert episode.started_at is not None
31
+
32
+ def test_episode_add_step(self) -> None:
33
+ """Test adding a step to episode."""
34
+ episode = Episode(episode_id="ep_003", task_id="task_003")
35
+ episode.start()
36
+
37
+ step = episode.add_step(
38
+ action_type="navigate",
39
+ action_params={"target": "/login"},
40
+ reward=0.5,
41
+ reward_breakdown={"progress": 0.5},
42
+ observation_summary={"url": "https://example.com"},
43
+ )
44
+
45
+ assert len(episode.steps) == 1
46
+ assert episode.steps[0].step_number == 1
47
+ assert episode.total_reward == 0.5
48
+
49
+ def test_episode_multiple_steps(self) -> None:
50
+ """Test adding multiple steps."""
51
+ episode = Episode(episode_id="ep_004", task_id="task_004")
52
+ episode.start()
53
+
54
+ rewards = [0.1, 0.2, 0.3, 0.4]
55
+ for i, reward in enumerate(rewards):
56
+ episode.add_step(
57
+ action_type="test",
58
+ action_params={"step": i},
59
+ reward=reward,
60
+ reward_breakdown={"base": reward},
61
+ observation_summary={"step": i},
62
+ )
63
+
64
+ assert len(episode.steps) == 4
65
+ assert episode.total_reward == pytest.approx(1.0)
66
+ assert episode.current_step == 4
67
+
68
+ def test_episode_completion(self) -> None:
69
+ """Test completing an episode."""
70
+ episode = Episode(episode_id="ep_005", task_id="task_005")
71
+ episode.start()
72
+ episode.complete(success=True)
73
+
74
+ assert episode.status == EpisodeStatus.COMPLETED
75
+ assert episode.ended_at is not None
76
+
77
+ def test_episode_failure(self) -> None:
78
+ """Test failing an episode."""
79
+ episode = Episode(episode_id="ep_006", task_id="task_006")
80
+ episode.start()
81
+ episode.fail(reason="Test failure")
82
+
83
+ assert episode.status == EpisodeStatus.FAILED
84
+ assert episode.failure_reason == "Test failure"
85
+
86
+ def test_episode_truncation(self) -> None:
87
+ """Test truncating an episode."""
88
+ episode = Episode(episode_id="ep_007", task_id="task_007", max_steps=5)
89
+ episode.start()
90
+
91
+ # Add steps up to max
92
+ for i in range(5):
93
+ episode.add_step(
94
+ action_type="test",
95
+ action_params={},
96
+ reward=0.1,
97
+ reward_breakdown={"base": 0.1},
98
+ observation_summary={},
99
+ )
100
+
101
+ episode.truncate()
102
+ assert episode.status == EpisodeStatus.TRUNCATED
103
+
104
+ def test_episode_is_terminal(self) -> None:
105
+ """Test terminal state check."""
106
+ episode = Episode(episode_id="ep_008", task_id="task_008")
107
+
108
+ assert not episode.is_terminal
109
+
110
+ episode.start()
111
+ assert not episode.is_terminal
112
+
113
+ episode.complete(success=True)
114
+ assert episode.is_terminal
115
+
116
+ def test_episode_duration(self) -> None:
117
+ """Test episode duration calculation."""
118
+ episode = Episode(episode_id="ep_009", task_id="task_009")
119
+ episode.start()
120
+
121
+ # Duration should be None before completion
122
+ import time
123
+
124
+ time.sleep(0.01) # Small delay
125
+ episode.complete(success=True)
126
+
127
+ assert episode.duration_seconds is not None
128
+ assert episode.duration_seconds >= 0
129
+
130
+ def test_episode_average_reward(self) -> None:
131
+ """Test average reward calculation."""
132
+ episode = Episode(episode_id="ep_010", task_id="task_010")
133
+ episode.start()
134
+
135
+ rewards = [0.2, 0.4, 0.6]
136
+ for i, reward in enumerate(rewards):
137
+ episode.add_step(
138
+ action_type="test",
139
+ action_params={},
140
+ reward=reward,
141
+ reward_breakdown={"base": reward},
142
+ observation_summary={},
143
+ )
144
+
145
+ assert episode.average_reward == pytest.approx(0.4)
146
+
147
+ def test_episode_summary(self) -> None:
148
+ """Test episode summary."""
149
+ episode = Episode(episode_id="ep_011", task_id="task_011")
150
+ episode.start()
151
+
152
+ summary = episode.get_summary()
153
+
154
+ assert summary["episode_id"] == "ep_011"
155
+ assert summary["task_id"] == "task_011"
156
+ assert "status" in summary
157
+ assert "steps" in summary
158
+
159
+ def test_episode_cancel(self) -> None:
160
+ """Test episode cancellation."""
161
+ episode = Episode(episode_id="ep_012", task_id="task_012")
162
+ episode.start()
163
+ episode.cancel()
164
+
165
+ assert episode.status == EpisodeStatus.CANCELLED
166
+ assert episode.is_terminal
167
+
168
+ def test_episode_get_action_sequence(self) -> None:
169
+ """Test getting action sequence."""
170
+ episode = Episode(episode_id="ep_013", task_id="task_013")
171
+ episode.start()
172
+
173
+ episode.add_step("navigate", {}, 0.1, {}, {})
174
+ episode.add_step("click", {}, 0.2, {}, {})
175
+ episode.add_step("extract", {}, 0.3, {}, {})
176
+
177
+ actions = episode.get_action_sequence()
178
+ assert actions == ["navigate", "click", "extract"]
179
+
180
+ def test_episode_get_reward_history(self) -> None:
181
+ """Test getting reward history."""
182
+ episode = Episode(episode_id="ep_014", task_id="task_014")
183
+ episode.start()
184
+
185
+ episode.add_step("a", {}, 0.1, {}, {})
186
+ episode.add_step("b", {}, 0.2, {}, {})
187
+ episode.add_step("c", {}, 0.3, {}, {})
188
+
189
+ rewards = episode.get_reward_history()
190
+ assert rewards == [0.1, 0.2, 0.3]
191
+
192
+
193
+ class TestEpisodeStep:
194
+ """Test EpisodeStep class."""
195
+
196
+ def test_step_creation(self) -> None:
197
+ """Test creating an episode step."""
198
+ from datetime import datetime, timezone
199
+
200
+ step = EpisodeStep(
201
+ step_number=1,
202
+ timestamp=datetime.now(timezone.utc).isoformat(),
203
+ action_type="click",
204
+ action_params={"selector": "#btn"},
205
+ reward=0.75,
206
+ reward_breakdown={"progress": 0.75},
207
+ observation_summary={"url": "https://example.com", "title": "Test"},
208
+ )
209
+
210
+ assert step.step_number == 1
211
+ assert step.action_type == "click"
212
+ assert step.action_params["selector"] == "#btn"
213
+ assert step.reward == 0.75
214
+
215
+ def test_step_with_error(self) -> None:
216
+ """Test step with error."""
217
+ from datetime import datetime, timezone
218
+
219
+ step = EpisodeStep(
220
+ step_number=1,
221
+ timestamp=datetime.now(timezone.utc).isoformat(),
222
+ action_type="click",
223
+ action_params={},
224
+ reward=-0.5,
225
+ reward_breakdown={"error": -0.5},
226
+ observation_summary={},
227
+ error="Element not found",
228
+ duration_ms=150.0,
229
+ )
230
+
231
+ assert step.error == "Element not found"
232
+ assert step.duration_ms == 150.0
233
+
234
+ def test_step_with_reasoning(self) -> None:
235
+ """Test step with action reasoning."""
236
+ from datetime import datetime, timezone
237
+
238
+ step = EpisodeStep(
239
+ step_number=1,
240
+ timestamp=datetime.now(timezone.utc).isoformat(),
241
+ action_type="extract",
242
+ action_params={"field": "price"},
243
+ action_reasoning="Extracting price from product page",
244
+ reward=0.5,
245
+ reward_breakdown={"extraction": 0.5},
246
+ observation_summary={},
247
+ )
248
+
249
+ assert step.action_reasoning == "Extracting price from product page"
250
+
251
+
252
+ class TestEpisodeManager:
253
+ """Test EpisodeManager class."""
254
+
255
+ def test_manager_create_episode(self) -> None:
256
+ """Test creating episode via manager."""
257
+ manager = EpisodeManager()
258
+ episode = manager.create_episode("ep_100", "task_100")
259
+
260
+ assert episode.episode_id == "ep_100"
261
+ assert episode.task_id == "task_100"
262
+
263
+ def test_manager_get_episode(self) -> None:
264
+ """Test getting episode from manager."""
265
+ manager = EpisodeManager()
266
+ manager.create_episode("ep_101", "task_101")
267
+
268
+ episode = manager.get_episode("ep_101")
269
+ assert episode is not None
270
+ assert episode.episode_id == "ep_101"
271
+
272
+ def test_manager_get_nonexistent(self) -> None:
273
+ """Test getting non-existent episode."""
274
+ manager = EpisodeManager()
275
+ episode = manager.get_episode("nonexistent")
276
+ assert episode is None
277
+
278
+ def test_manager_remove_episode(self) -> None:
279
+ """Test removing episode from manager."""
280
+ manager = EpisodeManager()
281
+ manager.create_episode("ep_102", "task_102")
282
+
283
+ removed = manager.remove_episode("ep_102")
284
+ assert removed is True
285
+
286
+ episode = manager.get_episode("ep_102")
287
+ assert episode is None
288
+
289
+ def test_manager_list_episodes(self) -> None:
290
+ """Test listing episodes."""
291
+ manager = EpisodeManager()
292
+ manager.create_episode("ep_103", "task_103")
293
+ manager.create_episode("ep_104", "task_104")
294
+ manager.create_episode("ep_105", "task_105")
295
+
296
+ episodes = manager.list_episodes()
297
+ assert len(episodes) == 3
298
+
299
+ def test_manager_list_episodes_by_status(self) -> None:
300
+ """Test listing episodes by status."""
301
+ manager = EpisodeManager()
302
+
303
+ ep1 = manager.create_episode("ep_106", "task_106")
304
+ ep2 = manager.create_episode("ep_107", "task_107")
305
+ ep3 = manager.create_episode("ep_108", "task_108")
306
+
307
+ ep1.start()
308
+ ep2.start()
309
+ ep2.complete(success=True)
310
+
311
+ running = manager.list_episodes(status=EpisodeStatus.RUNNING)
312
+ assert len(running) == 1
313
+ assert running[0].episode_id == "ep_106"
314
+
315
+ completed = manager.list_episodes(status=EpisodeStatus.COMPLETED)
316
+ assert len(completed) == 1
317
+ assert completed[0].episode_id == "ep_107"
318
+
319
+ def test_manager_list_episodes_by_task(self) -> None:
320
+ """Test listing episodes by task ID."""
321
+ manager = EpisodeManager()
322
+ manager.create_episode("ep_109", "task_A")
323
+ manager.create_episode("ep_110", "task_A")
324
+ manager.create_episode("ep_111", "task_B")
325
+
326
+ task_a_episodes = manager.list_episodes(task_id="task_A")
327
+ assert len(task_a_episodes) == 2
328
+
329
+ task_b_episodes = manager.list_episodes(task_id="task_B")
330
+ assert len(task_b_episodes) == 1
backend/tests/test_models/test_base_simple.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Simple tests to verify the base model structures."""
2
+
3
+ import pytest
4
+ from app.models.providers.base import (
5
+ TokenUsage,
6
+ CompletionResponse,
7
+ ModelInfo,
8
+ ProviderError
9
+ )
10
+
11
+
12
+ def test_token_usage_creation():
13
+ """Test TokenUsage creation and addition."""
14
+ usage1 = TokenUsage(prompt_tokens=10, completion_tokens=20, total_tokens=30)
15
+ usage2 = TokenUsage(prompt_tokens=5, completion_tokens=10, total_tokens=15)
16
+
17
+ combined = usage1 + usage2
18
+ assert combined.prompt_tokens == 15
19
+ assert combined.completion_tokens == 30
20
+ assert combined.total_tokens == 45
21
+
22
+
23
+ def test_completion_response_creation():
24
+ """Test CompletionResponse creation."""
25
+ usage = TokenUsage(prompt_tokens=10, completion_tokens=20, total_tokens=30)
26
+
27
+ response = CompletionResponse(
28
+ content="Hello world",
29
+ model="test-model",
30
+ provider="test-provider",
31
+ usage=usage,
32
+ finish_reason="stop",
33
+ cost=0.001
34
+ )
35
+
36
+ assert response.content == "Hello world"
37
+ assert response.model == "test-model"
38
+ assert response.provider == "test-provider"
39
+ assert response.usage.total_tokens == 30
40
+ assert response.cost == 0.001
41
+
42
+
43
+ def test_model_info_creation():
44
+ """Test ModelInfo creation."""
45
+ info = ModelInfo(
46
+ id="test-model",
47
+ name="Test Model",
48
+ provider="test",
49
+ context_window=4096,
50
+ max_output_tokens=1000,
51
+ cost_per_1k_input=0.001,
52
+ cost_per_1k_output=0.002
53
+ )
54
+
55
+ assert info.id == "test-model"
56
+ assert info.context_window == 4096
57
+ assert info.cost_per_million_input == 1.0
58
+ assert info.cost_per_million_output == 2.0
59
+
60
+
61
+ def test_provider_error():
62
+ """Test ProviderError creation."""
63
+ error = ProviderError("Test error", "test-provider", 500)
64
+
65
+ assert error.message == "Test error"
66
+ assert error.provider == "test-provider"
67
+ assert error.status_code == 500
68
+ assert str(error) == "[test-provider] Test error"