Cyber-Machine committed on
Commit
aea0016
·
verified ·
0 Parent(s):

init: WorkFlowArena

Browse files
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
__init__.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""WorkflowArena package exports."""

from workflow_arena.client import WorkflowArenaEnv
from workflow_arena.generator import generate_episode
from workflow_arena.models import (
    DifficultyPreset,
    EpisodeConfig,
    TaskStatus,
    WorkflowActionType,
    WorkflowArenaAction,
    WorkflowArenaObservation,
)
from workflow_arena.presets import PRESET_CONFIGS, get_preset_config
from workflow_arena.server.workflow_arena_environment import WorkflowArenaEnvironment

# Public names re-exported at package level.
__all__ = [
    "DifficultyPreset",
    "EpisodeConfig",
    "PRESET_CONFIGS",
    "TaskStatus",
    "WorkflowActionType",
    "WorkflowArenaAction",
    "WorkflowArenaEnv",
    "WorkflowArenaEnvironment",
    "WorkflowArenaObservation",
    "generate_episode",
    "get_preset_config",
]
client.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """WorkflowArena client."""
8
+
9
+ from typing import Dict
10
+
11
+ from openenv.core import EnvClient
12
+ from openenv.core.client_types import StepResult
13
+ from openenv.core.env_server.types import State
14
+
15
+ from workflow_arena.models import WorkflowArenaAction, WorkflowArenaObservation
16
+
17
+
18
class WorkflowArenaEnv(
    EnvClient[WorkflowArenaAction, WorkflowArenaObservation, State]
):
    """Typed client for the WorkflowArena server."""

    def _step_payload(self, action: WorkflowArenaAction) -> Dict:
        """Serialize a typed action into the JSON body sent to the server."""
        return action.model_dump(mode="json")

    def _parse_result(self, payload: Dict) -> StepResult[WorkflowArenaObservation]:
        """Parse a raw server response into a typed StepResult."""
        done = payload.get("done", False)
        reward = payload.get("reward")

        # Fold the top-level done/reward flags into the observation payload
        # so the observation model carries them as well.
        merged = dict(payload.get("observation", {}))
        merged["done"] = done
        merged["reward"] = reward
        observation = WorkflowArenaObservation.model_validate(merged)

        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> State:
        """Parse a raw state response into the generic OpenEnv state type."""
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
generator.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Seeded workflow DAG generator and derived static metrics for WorkflowArena."""
8
+
9
+ from __future__ import annotations
10
+
11
+ import random
12
+
13
+ from workflow_arena.models import (
14
+ EpisodeConfig,
15
+ TaskStatus,
16
+ WorkflowEnvStateSnapshot,
17
+ WorkflowEpisodeSpec,
18
+ WorkflowTaskSpec,
19
+ )
20
+ from workflow_arena.presets import get_preset_config
21
+
22
+
23
+ def _task_id(index: int) -> str:
24
+ return f"task_{index:02d}"
25
+
26
+
27
+ def _compute_earliest_start(task_map: dict[str, WorkflowTaskSpec], task_id: str) -> int:
28
+ task = task_map[task_id]
29
+ if not task.dependencies:
30
+ return 0
31
+ return max(
32
+ _compute_earliest_start(task_map, dep_id) + task_map[dep_id].duration
33
+ for dep_id in task.dependencies
34
+ )
35
+
36
+
37
+ def _compute_critical_path(task_map: dict[str, WorkflowTaskSpec], task_id: str) -> int:
38
+ task = task_map[task_id]
39
+ if not task.dependents:
40
+ return task.duration
41
+ return task.duration + max(
42
+ _compute_critical_path(task_map, child_id) for child_id in task.dependents
43
+ )
44
+
45
+
46
+ def _compute_downstream_count(
47
+ task_map: dict[str, WorkflowTaskSpec], task_id: str, seen: set[str] | None = None
48
+ ) -> int:
49
+ task = task_map[task_id]
50
+ local_seen = set() if seen is None else seen
51
+ count = 0
52
+ for child_id in task.dependents:
53
+ if child_id in local_seen:
54
+ continue
55
+ local_seen.add(child_id)
56
+ count += 1 + _compute_downstream_count(task_map, child_id, local_seen)
57
+ return count
58
+
59
+
60
+ def _estimate_deadline(
61
+ task: WorkflowTaskSpec,
62
+ workflow_critical_path: int,
63
+ rng: random.Random,
64
+ tightness: float,
65
+ ) -> int:
66
+ slack_allowance = max(1, int(round((workflow_critical_path - task.earliest_start) * (1.15 - tightness))))
67
+ jitter = rng.randint(0, max(1, task.duration // 2))
68
+ return task.earliest_start + task.duration + slack_allowance + jitter
69
+
70
+
71
def generate_episode(
    config: EpisodeConfig,
) -> tuple[WorkflowEpisodeSpec, WorkflowEnvStateSnapshot]:
    """Generate a deterministic workflow episode from a preset and seed.

    Returns the static episode spec (tasks with derived DAG metrics) and the
    initial environment state snapshot at time 0.
    """

    preset_config = get_preset_config(config.preset)
    # An explicit worker_count on the config wins over the preset default.
    worker_count = config.worker_count or preset_config.worker_count
    resolved_config = config.model_copy(update={"worker_count": worker_count})
    # Single seeded RNG; every draw below consumes it in a fixed order, which
    # is what makes the episode deterministic per (preset, seed).
    rng = random.Random(resolved_config.seed)
    task_count = rng.randint(preset_config.min_tasks, preset_config.max_tasks)

    dependency_map: dict[str, list[str]] = {}
    dependent_map: dict[str, list[str]] = {}
    task_ids = [_task_id(index + 1) for index in range(task_count)]

    # Edges only ever point from earlier task ids to later ones, so the
    # generated graph is a DAG by construction.
    for index, task_id in enumerate(task_ids):
        candidates = task_ids[:index]
        dependencies: list[str] = []
        if candidates:
            for candidate in candidates:
                if rng.random() < preset_config.edge_probability:
                    dependencies.append(candidate)
            # Bias toward connectivity: a non-root task with no sampled edges
            # still gets one dependency 60% of the time.
            if not dependencies and index > 0 and rng.random() < 0.6:
                dependencies.append(rng.choice(candidates))
        # De-duplicate while preserving creation order of the task ids.
        dependency_map[task_id] = sorted(set(dependencies), key=task_ids.index)
        dependent_map[task_id] = []

    # Invert the dependency map to get downstream edges.
    for task_id, dependencies in dependency_map.items():
        for dependency in dependencies:
            dependent_map[dependency].append(task_id)

    tasks = [
        WorkflowTaskSpec(
            task_id=task_id,
            duration=rng.randint(preset_config.duration_min, preset_config.duration_max),
            priority=rng.randint(preset_config.priority_min, preset_config.priority_max),
            dependencies=dependency_map[task_id],
            dependents=sorted(dependent_map[task_id], key=task_ids.index),
            deadline=None,  # filled in below once critical-path metrics exist
        )
        for task_id in task_ids
    ]

    task_map = {task.task_id: task for task in tasks}

    # Derive static DAG metrics used for grading and deadline estimation.
    workflow_critical_path = 0
    for task in tasks:
        task.earliest_start = _compute_earliest_start(task_map, task.task_id)
        task.critical_path_length = _compute_critical_path(task_map, task.task_id)
        task.downstream_count = _compute_downstream_count(task_map, task.task_id)
        workflow_critical_path = max(
            workflow_critical_path, task.earliest_start + task.duration
        )

    workflow_critical_path = max(
        workflow_critical_path,
        max(task.critical_path_length for task in tasks),
    )

    max_downstream = max(task.downstream_count for task in tasks) if tasks else 1
    max_critical_path = max(task.critical_path_length for task in tasks) if tasks else 1

    for task in tasks:
        latest_start = max(
            task.earliest_start, workflow_critical_path - task.critical_path_length
        )
        task.slack = max(0, latest_start - task.earliest_start)
        # Weighted blend of normalized path length and fan-out.
        task.criticality = round(
            0.7 * (task.critical_path_length / max_critical_path)
            + 0.3 * (task.downstream_count / max(1, max_downstream)),
            4,
        )
        task.deadline = _estimate_deadline(
            task=task,
            workflow_critical_path=workflow_critical_path,
            rng=rng,
            tightness=preset_config.deadline_tightness,
        )

    episode = WorkflowEpisodeSpec(
        config=resolved_config,
        preset_config=preset_config,
        tasks=tasks,
    )

    ready_task_ids = [task.task_id for task in tasks if not task.dependencies]
    blocked_task_ids = [task.task_id for task in tasks if task.dependencies]

    # Initial snapshot: clock at zero, nothing running or completed yet.
    state = WorkflowEnvStateSnapshot(
        episode_id=f"seed-{resolved_config.seed}",
        current_time=0,
        task_statuses={
            task.task_id: (
                TaskStatus.READY if not task.dependencies else TaskStatus.BLOCKED
            )
            for task in tasks
        },
        running_task_ids=[],
        completed_task_ids=[],
        ready_task_ids=ready_task_ids,
        blocked_task_ids=blocked_task_ids,
        task_start_times={},
        task_end_times={},
        task_remaining_dependencies={
            task.task_id: len(task.dependencies) for task in tasks
        },
        task_assigned_finish_times={},
        task_attempt_counts={task.task_id: 0 for task in tasks},
        cumulative_busy_time=0,
        time_budget=None,
        degraded_workers=0,
        active_worker_outage_until=None,
        recent_failure_events=[],
    )
    return episode, state
models.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Typed models for WorkflowArena.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from enum import Enum
14
+
15
+ from openenv.core.env_server.types import Action, Observation
16
+ from pydantic import BaseModel, Field
17
+
18
+
19
class TaskStatus(str, Enum):
    """Allowed lifecycle states for a workflow task."""

    # Roughly ordered by lifecycle: blocked -> ready -> running -> completed.
    BLOCKED = "blocked"
    READY = "ready"
    RUNNING = "running"
    COMPLETED = "completed"


class DifficultyPreset(str, Enum):
    """Initial task presets required by the hackathon."""

    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"


class WorkflowActionType(str, Enum):
    """Explicit action space for the scheduler agent."""

    # Dispatch one or more ready tasks; wait advances to the next event.
    DISPATCH = "dispatch"
    WAIT = "wait"
41
+
42
+
43
class RewardBreakdown(BaseModel):
    """Named reward channels for shaped feedback.

    All channels default to 0.0, so a fresh instance represents a no-op step.
    """

    completion_reward: float = Field(
        default=0.0, description="Reward for completing tasks."
    )
    utilization_reward: float = Field(
        default=0.0, description="Reward for keeping workers busy."
    )
    deadline_reward: float = Field(
        default=0.0, description="Reward or penalty tied to deadlines."
    )
    criticality_reward: float = Field(
        default=0.0,
        description="Reward for prioritizing critical-path work appropriately.",
    )
    idle_penalty: float = Field(
        default=0.0, description="Penalty for leaving workers idle."
    )
    invalid_action_penalty: float = Field(
        default=0.0,
        description="Penalty for malformed or infeasible actions.",
    )
    terminal_makespan_score: float = Field(
        default=0.0,
        description="Terminal score based on final schedule quality.",
    )
    unfinished_task_penalty: float = Field(
        default=0.0,
        description="Terminal penalty for unfinished work at episode end.",
    )


class FailureEventType(str, Enum):
    """Failure events surfaced to agents and the UI."""

    WORKER_OUTAGE_START = "worker_outage_start"
    WORKER_OUTAGE_END = "worker_outage_end"
    TASK_RETRY_FAILURE = "task_retry_failure"


class WorkflowFailureEvent(BaseModel):
    """Structured failure event emitted by the environment."""

    event_type: FailureEventType = Field(..., description="Failure category.")
    time: int = Field(..., ge=0, description="Simulated time when the event was observed.")
    task_id: str | None = Field(default=None, description="Task affected by the event, if any.")
    worker_delta: int = Field(default=0, description="Net temporary change in usable workers.")
    duration: int | None = Field(default=None, ge=0, description="Outage duration when applicable.")
    detail: str = Field(default="", description="Short human-readable summary.")
93
+
94
+
95
class WorkflowTaskView(BaseModel):
    """Compact task payload used in observations and the future UI."""

    task_id: str = Field(..., description="Stable task identifier.")
    status: TaskStatus = Field(..., description="Current task lifecycle state.")
    duration: int = Field(
        ..., ge=1, description="Task runtime in simulated time units."
    )
    priority: int = Field(..., ge=0, description="Priority weight for the task.")
    dependencies: list[str] = Field(
        default_factory=list,
        description="Upstream task ids that must complete first.",
    )
    deadline: int | None = Field(
        default=None,
        ge=0,
        description="Optional deadline in simulated time units.",
    )
    criticality: float | None = Field(
        default=None,
        description="Derived importance score from the DAG structure.",
    )
    slack: float | None = Field(
        default=None,
        description="Derived slack estimate for scheduling decisions.",
    )
    downstream_count: int = Field(
        default=0,
        ge=0,
        description="Count of downstream dependents reachable from this task.",
    )
    start_time: int | None = Field(
        default=None,
        ge=0,
        description="Simulated start time if the task is running or completed.",
    )
    end_time: int | None = Field(
        default=None,
        ge=0,
        description="Simulated end time if the task is completed or scheduled to finish.",
    )
    attempt_count: int = Field(
        default=0,
        ge=0,
        description="Number of retry attempts already consumed by this task.",
    )


class WorkflowTaskSpec(BaseModel):
    """Static task specification generated at episode reset.

    The derived metric fields (downstream_count, critical_path_length,
    earliest_start, slack, criticality, deadline) are filled in by the
    generator after the DAG is built.
    """

    task_id: str = Field(..., description="Stable task identifier.")
    duration: int = Field(..., ge=1, description="Task runtime in simulated time units.")
    priority: int = Field(..., ge=0, description="Priority weight for the task.")
    dependencies: list[str] = Field(
        default_factory=list,
        description="Upstream task ids that must complete first.",
    )
    dependents: list[str] = Field(
        default_factory=list,
        description="Downstream task ids that depend on this task.",
    )
    deadline: int | None = Field(
        default=None,
        ge=0,
        description="Optional deadline in simulated time units.",
    )
    downstream_count: int = Field(
        default=0,
        ge=0,
        description="Number of downstream tasks reachable from this node.",
    )
    critical_path_length: int = Field(
        default=0,
        ge=0,
        description="Duration-weighted path length from this task to a sink.",
    )
    earliest_start: int = Field(
        default=0,
        ge=0,
        description="Earliest feasible start time under dependency constraints.",
    )
    slack: int = Field(
        default=0,
        ge=0,
        description="Scheduling slack measured in simulated time units.",
    )
    criticality: float = Field(
        default=0.0,
        description="Normalized importance score derived from critical path and downstream impact.",
    )
186
+
187
+
188
class ProgressSummary(BaseModel):
    """Counts by task lifecycle state."""

    total: int = Field(default=0, ge=0)
    blocked: int = Field(default=0, ge=0)
    ready: int = Field(default=0, ge=0)
    running: int = Field(default=0, ge=0)
    completed: int = Field(default=0, ge=0)


class EpisodeConfig(BaseModel):
    """Reset-time knobs that define the episode."""

    preset: DifficultyPreset = Field(
        default=DifficultyPreset.EASY,
        description="Difficulty preset for the episode generator.",
    )
    seed: int = Field(
        default=0, description="Seed for deterministic episode generation."
    )
    worker_count: int = Field(
        default=2,
        ge=1,
        description="Number of identical workers available to the scheduler.",
    )


class GraderTarget(BaseModel):
    """High-level target bands for each preset's grader."""

    description: str = Field(..., description="What good performance means for the preset.")
    score_band_hint: str = Field(..., description="Human-readable interpretation of scores.")


class DifficultyPresetConfig(BaseModel):
    """Concrete generator knobs for a preset.

    NOTE(review): no validator enforces min_tasks <= max_tasks or
    duration_min <= duration_max — confirm preset definitions keep these
    ordered, since random.Random.randint raises otherwise.
    """

    preset: DifficultyPreset = Field(..., description="Preset identifier.")
    min_tasks: int = Field(..., ge=2)
    max_tasks: int = Field(..., ge=2)
    edge_probability: float = Field(..., ge=0.0, le=1.0)
    duration_min: int = Field(..., ge=1)
    duration_max: int = Field(..., ge=1)
    priority_min: int = Field(..., ge=0)
    priority_max: int = Field(..., ge=0)
    worker_count: int = Field(..., ge=1)
    deadline_tightness: float = Field(
        ...,
        ge=0.0,
        description="Larger values mean tighter deadlines.",
    )
    time_budget_multiplier: float | None = Field(
        default=None,
        gt=0.0,
        description="Optional multiplier over the theoretical lower-bound makespan.",
    )
    worker_outage_rate: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Chance of a hard-mode worker outage being sampled on a wait transition.",
    )
    worker_outage_duration_min: int = Field(
        default=0,
        ge=0,
        description="Minimum outage duration in simulated time units.",
    )
    worker_outage_duration_max: int = Field(
        default=0,
        ge=0,
        description="Maximum outage duration in simulated time units.",
    )
    task_retry_failure_rate: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Chance that a hard-mode task completion becomes a retry failure.",
    )
    max_task_retries: int = Field(
        default=0,
        ge=0,
        description="Maximum number of retry failures a task may suffer before it must complete.",
    )
    grader_target: GraderTarget = Field(
        ...,
        description="Preset-specific grader interpretation.",
    )
275
+
276
+
277
class WorkflowEpisodeSpec(BaseModel):
    """Static episode description produced by the generator."""

    config: EpisodeConfig = Field(..., description="Reset-time configuration.")
    preset_config: DifficultyPresetConfig = Field(..., description="Resolved preset parameters.")
    tasks: list[WorkflowTaskSpec] = Field(..., description="Generated workflow tasks.")


class WorkflowEnvStateSnapshot(BaseModel):
    """Serializable environment state for the current episode."""

    episode_id: str = Field(..., description="Stable current episode identifier.")
    current_time: int = Field(default=0, ge=0, description="Current simulated time.")
    task_statuses: dict[str, TaskStatus] = Field(
        default_factory=dict,
        description="Current task status by task id.",
    )
    running_task_ids: list[str] = Field(
        default_factory=list,
        description="Tasks currently consuming workers.",
    )
    completed_task_ids: list[str] = Field(
        default_factory=list,
        description="Tasks that have completed.",
    )
    ready_task_ids: list[str] = Field(
        default_factory=list,
        description="Tasks currently ready for dispatch.",
    )
    blocked_task_ids: list[str] = Field(
        default_factory=list,
        description="Tasks still blocked on dependencies.",
    )
    task_start_times: dict[str, int] = Field(
        default_factory=dict,
        description="Simulated start time by task id.",
    )
    task_end_times: dict[str, int] = Field(
        default_factory=dict,
        description="Simulated completion time by task id.",
    )
    task_remaining_dependencies: dict[str, int] = Field(
        default_factory=dict,
        description="Remaining unfinished prerequisites by task id.",
    )
    task_assigned_finish_times: dict[str, int] = Field(
        default_factory=dict,
        description="Predicted completion times for currently running tasks.",
    )
    task_attempt_counts: dict[str, int] = Field(
        default_factory=dict,
        description="Retry attempts consumed by each task.",
    )
    cumulative_busy_time: int = Field(
        default=0,
        ge=0,
        description="Aggregate worker busy time accrued so far.",
    )
    time_budget: int | None = Field(
        default=None,
        ge=0,
        description="Optional terminal time budget for the episode.",
    )
    degraded_workers: int = Field(
        default=0,
        ge=0,
        description="Workers temporarily removed from usable capacity.",
    )
    active_worker_outage_until: int | None = Field(
        default=None,
        ge=0,
        description="Time when the current worker outage expires, if any.",
    )
    recent_failure_events: list[WorkflowFailureEvent] = Field(
        default_factory=list,
        description="Failure events generated on the latest transition.",
    )


class SuccessMetrics(BaseModel):
    """Primary quality metrics used for grading and demos."""

    makespan: int | None = Field(
        default=None, description="Total simulated completion time."
    )
    worker_utilization: float | None = Field(
        default=None,
        description="Fraction of available worker time that was used.",
    )
    deadline_miss_count: int = Field(
        default=0, ge=0, description="Missed task deadlines."
    )
    unfinished_task_count: int = Field(
        default=0, ge=0, description="Tasks left incomplete at terminal time."
    )
    weighted_priority_completion: float | None = Field(
        default=None,
        description="Priority-weighted on-time completion score.",
    )
    benchmark_score: float | None = Field(
        default=None,
        description="Deterministic terminal benchmark score in the 0.0-1.0 range.",
    )
380
+
381
+
382
class WorkflowArenaAction(Action):
    """Strict action space for the workflow scheduler."""

    action_type: WorkflowActionType = Field(
        ...,
        description="Dispatch ready tasks or wait for the next completion event.",
    )
    task_ids: list[str] = Field(
        default_factory=list,
        description="Task ids to dispatch. Must be empty for wait().",
    )


class WorkflowArenaObservation(Observation):
    """Compact, typed observation contract for WorkflowArena."""

    instruction: str = Field(
        default=(
            "Schedule dependency-constrained workflow tasks on limited workers using "
            "dispatch(task_ids=[...]) or wait()."
        ),
        description="Short prompt shown to inference agents.",
    )
    config: EpisodeConfig = Field(
        default_factory=EpisodeConfig,
        description="Episode generation settings.",
    )
    current_time: int = Field(default=0, ge=0, description="Current simulated time.")
    total_workers: int = Field(default=2, ge=1, description="Total identical workers.")
    effective_workers: int = Field(
        default=2,
        ge=0,
        description="Usable workers after temporary degradation is applied.",
    )
    degraded_workers: int = Field(
        default=0,
        ge=0,
        description="Workers currently unavailable due to outages.",
    )
    free_workers: int = Field(default=2, ge=0, description="Currently idle workers.")
    time_budget: int | None = Field(
        default=None,
        ge=0,
        description="Optional terminal time budget for the current episode.",
    )
    time_remaining: int | None = Field(
        default=None,
        description="Remaining time until the episode budget expires, if budgeted.",
    )
    progress: ProgressSummary = Field(
        default_factory=ProgressSummary,
        description="Task counts by lifecycle state.",
    )
    ready_tasks: list[WorkflowTaskView] = Field(
        default_factory=list,
        description="Ready tasks eligible for dispatch.",
    )
    running_tasks: list[WorkflowTaskView] = Field(
        default_factory=list,
        description="Tasks currently consuming workers.",
    )
    completed_tasks: list[WorkflowTaskView] = Field(
        default_factory=list,
        description="Tasks already completed.",
    )
    blocked_tasks: list[WorkflowTaskView] = Field(
        default_factory=list,
        description="Tasks still waiting on dependencies.",
    )
    last_reward_breakdown: RewardBreakdown = Field(
        default_factory=RewardBreakdown,
        description="Per-step reward channel breakdown.",
    )
    cumulative_reward: float = Field(default=0.0, description="Running total reward.")
    success_metrics: SuccessMetrics = Field(
        default_factory=SuccessMetrics,
        description="Primary schedule quality metrics.",
    )
    note: str | None = Field(
        default=None,
        description="Short environment note about the latest transition.",
    )
    validation_error: str | None = Field(
        default=None,
        description="Explicit invalid-action explanation when the previous action failed.",
    )
    termination_reason: str | None = Field(
        default=None,
        description="Terminal reason when the episode ended unsuccessfully.",
    )
    benchmark_score: float | None = Field(
        default=None,
        description="Top-level bounded benchmark score for easier client access.",
    )
    recent_failure_events: list[WorkflowFailureEvent] = Field(
        default_factory=list,
        description="Failure events generated on the latest accepted transition.",
    )
    received_action: dict[str, object] | None = Field(
        default=None,
        description="Last action accepted by the server for logging and prompting.",
    )
openenv.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: workflow_arena
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000
7
+
presets.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Difficulty presets for WorkflowArena."""
8
+
9
+ from __future__ import annotations
10
+
11
+ from workflow_arena.models import DifficultyPreset, DifficultyPresetConfig, GraderTarget
12
+
13
+
14
# Registry of generator knobs per difficulty level. Access through
# get_preset_config() to receive a defensive copy.
PRESET_CONFIGS: dict[DifficultyPreset, DifficultyPresetConfig] = {
    # Small, sparse DAG with loose deadlines and no failure modes.
    DifficultyPreset.EASY: DifficultyPresetConfig(
        preset=DifficultyPreset.EASY,
        min_tasks=8,
        max_tasks=12,
        edge_probability=0.14,
        duration_min=1,
        duration_max=4,
        priority_min=1,
        priority_max=4,
        worker_count=3,
        deadline_tightness=0.22,
        time_budget_multiplier=None,
        worker_outage_rate=0.0,
        worker_outage_duration_min=0,
        worker_outage_duration_max=0,
        task_retry_failure_rate=0.0,
        max_task_retries=0,
        grader_target=GraderTarget(
            description=(
                "Reward agents that keep workers utilized and avoid obvious idle time on a "
                "small, low-pressure workflow."
            ),
            score_band_hint="0.8+ means near-greedy scheduling, 0.5 is acceptable, below 0.3 is weak.",
        ),
    ),
    # Larger, denser DAG with a time budget and tighter deadlines.
    DifficultyPreset.MEDIUM: DifficultyPresetConfig(
        preset=DifficultyPreset.MEDIUM,
        min_tasks=12,
        max_tasks=18,
        edge_probability=0.22,
        duration_min=1,
        duration_max=6,
        priority_min=1,
        priority_max=6,
        worker_count=4,
        deadline_tightness=0.40,
        time_budget_multiplier=1.6,
        worker_outage_rate=0.0,
        worker_outage_duration_min=0,
        worker_outage_duration_max=0,
        task_retry_failure_rate=0.0,
        max_task_retries=0,
        grader_target=GraderTarget(
            description=(
                "Reward agents that balance utilization, deadline adherence, and critical-path "
                "awareness on a moderately branching workflow."
            ),
            score_band_hint="0.75+ is strong, 0.45 to 0.75 is competitive, below 0.3 misses core tradeoffs.",
        ),
    ),
    # Dense DAG, fewer workers, plus worker outages and retry failures.
    DifficultyPreset.HARD: DifficultyPresetConfig(
        preset=DifficultyPreset.HARD,
        min_tasks=22,
        max_tasks=36,
        edge_probability=0.37,
        duration_min=2,
        duration_max=9,
        priority_min=1,
        priority_max=8,
        worker_count=2,
        deadline_tightness=0.78,
        time_budget_multiplier=1.45,
        worker_outage_rate=0.2,
        worker_outage_duration_min=2,
        worker_outage_duration_max=4,
        task_retry_failure_rate=0.12,
        max_task_retries=1,
        grader_target=GraderTarget(
            description=(
                "Reward agents that identify and schedule long-running critical tasks early while "
                "protecting high-priority deadlines under frequent worker-capacity bottlenecks."
            ),
            score_band_hint="0.7+ is excellent, 0.4 to 0.7 is competent, below 0.25 is poor planning.",
        ),
    ),
}
91
+
92
+
93
def get_preset_config(preset: DifficultyPreset) -> DifficultyPresetConfig:
    """Return the immutable config for a preset.

    A deep copy is handed out so callers can mutate the result without
    affecting the shared PRESET_CONFIGS registry.
    """
    base_config = PRESET_CONFIGS[preset]
    return base_config.model_copy(deep=True)
pyproject.toml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ [build-system]
8
+ requires = ["setuptools>=45", "wheel"]
9
+ build-backend = "setuptools.build_meta"
10
+
11
+ [project]
12
+ name = "openenv-workflow_arena"
13
+ version = "0.1.0"
14
+ description = "Workflow Arena environment for OpenEnv"
15
+ requires-python = ">=3.10"
16
+ dependencies = [
17
+ # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
18
+ # install from github
19
+ # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
20
+ "openenv-core[core]>=0.2.2",
21
+ "gradio>=5.0.0",
22
+ "plotly>=5.24.0",
23
+ # Environment-specific dependencies
24
+ # Add all dependencies needed for your environment here
25
+ # Examples:
26
+ # "numpy>=1.19.0",
27
+ # "torch>=2.0.0",
28
+ # "gymnasium>=0.29.0",
29
+ # "openspiel>=1.0.0",
30
+ # "smolagents>=1.22.0,<2",
31
+ ]
32
+
33
+ [project.optional-dependencies]
34
+ dev = [
35
+ "pytest>=8.0.0",
36
+ "pytest-cov>=4.0.0",
37
+ ]
38
+
39
+ [project.scripts]
40
+ # Server entry point - enables running via: uv run --project . server
41
+ # or: python -m workflow_arena.server.app
42
+ server = "workflow_arena.server.app:main"
43
+
44
+ [tool.setuptools]
45
+ include-package-data = true
46
+ packages = ["workflow_arena", "workflow_arena.server"]
47
+ package-dir = { "workflow_arena" = ".", "workflow_arena.server" = "server" }
server/Dockerfile ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Multi-stage build using openenv-base
8
+ # This Dockerfile is flexible and works for both:
9
+ # - In-repo environments (with local OpenEnv sources)
10
+ # - Standalone environments (with openenv from PyPI/Git)
11
+ # The build script (openenv build) handles context detection and sets appropriate build args.
12
+
13
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
14
+ FROM ${BASE_IMAGE} AS builder
15
+
16
+ WORKDIR /app
17
+
18
+ # Ensure git is available (required for installing dependencies from VCS)
19
+ RUN apt-get update && \
20
+ apt-get install -y --no-install-recommends git && \
21
+ rm -rf /var/lib/apt/lists/*
22
+
23
+ # Build argument to control whether we're building standalone or in-repo
24
+ ARG BUILD_MODE=in-repo
25
+ ARG ENV_NAME=workflow_arena
26
+
27
+ # Copy environment code (always at root of build context)
28
+ COPY . /app/env
29
+
30
+ # For in-repo builds, openenv is already vendored in the build context
31
+ # For standalone builds, openenv will be installed via pyproject.toml
32
+ WORKDIR /app/env
33
+
34
+ # Ensure uv is available (for local builds where base image lacks it)
35
+ RUN if ! command -v uv >/dev/null 2>&1; then \
36
+ curl -LsSf https://astral.sh/uv/install.sh | sh && \
37
+ mv /root/.local/bin/uv /usr/local/bin/uv && \
38
+ mv /root/.local/bin/uvx /usr/local/bin/uvx; \
39
+ fi
40
+
41
+ # Install dependencies using uv sync
42
+ # If uv.lock exists, use it; otherwise resolve on the fly
43
+ RUN --mount=type=cache,target=/root/.cache/uv \
44
+ if [ -f uv.lock ]; then \
45
+ uv sync --frozen --no-install-project --no-editable; \
46
+ else \
47
+ uv sync --no-install-project --no-editable; \
48
+ fi
49
+
50
+ RUN --mount=type=cache,target=/root/.cache/uv \
51
+ if [ -f uv.lock ]; then \
52
+ uv sync --frozen --no-editable; \
53
+ else \
54
+ uv sync --no-editable; \
55
+ fi
56
+
57
+ # Final runtime stage
58
+ FROM ${BASE_IMAGE}
59
+
60
+ WORKDIR /app
61
+
62
+ # Copy the virtual environment from builder
63
+ COPY --from=builder /app/env/.venv /app/.venv
64
+
65
+ # Copy the environment code
66
+ COPY --from=builder /app/env /app/env
67
+
68
+ # Set PATH to use the virtual environment
69
+ ENV PATH="/app/.venv/bin:$PATH"
70
+
71
+ # Set PYTHONPATH so imports work correctly
72
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
73
+
74
+ # Health check
75
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
76
+ CMD curl -f http://localhost:8000/health || exit 1
77
+
78
+ # Run the FastAPI server
79
+ # The module path is constructed to work with the /app/env structure
80
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
server/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Workflow Arena environment server components."""
8
+
9
+ from workflow_arena.server.workflow_arena_environment import WorkflowArenaEnvironment
10
+
11
+ __all__ = ["WorkflowArenaEnvironment"]
server/app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ FastAPI application for the Workflow Arena Environment.
9
+
10
+ This module creates an HTTP server that exposes the WorkflowArenaEnvironment
11
+ over HTTP and WebSocket endpoints, compatible with EnvClient.
12
+
13
+ Endpoints:
14
+ - POST /reset: Reset the environment
15
+ - POST /step: Execute an action
16
+ - GET /state: Get current environment state
17
+ - GET /schema: Get action/observation schemas
18
+ - WS /ws: WebSocket endpoint for persistent sessions
19
+
20
+ Usage:
21
+ # Development (with auto-reload):
22
+ uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
23
+
24
+ # Production:
25
+ uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
26
+
27
+ # Or run directly:
28
+ python -m server.app
29
+ """
30
+
31
import gradio as gr

# openenv supplies the FastAPI wrapper (create_app); fail fast with an
# actionable message when it is missing.
try:
    from openenv.core.env_server.http_server import create_app
except Exception as e:  # pragma: no cover
    raise ImportError(
        "openenv is required for the web interface. Install dependencies with '\n uv sync\n'"
    ) from e

from workflow_arena.models import WorkflowArenaAction, WorkflowArenaObservation
from workflow_arena.server.ui import create_gradio_app
from workflow_arena.server.workflow_arena_environment import WorkflowArenaEnvironment


# Create the app with web interface and README integration.
# create_app wires the environment class plus its action/observation
# schemas into the standard OpenEnv HTTP/WebSocket endpoints.
app = create_app(
    WorkflowArenaEnvironment,
    WorkflowArenaAction,
    WorkflowArenaObservation,
    env_name="workflow_arena",
    max_concurrent_envs=1,  # increase this number to allow more concurrent WebSocket sessions
)

# Mount Gradio UI at root — MUST be after all API routes to avoid catchall interference
_gradio_app = create_gradio_app()
app = gr.mount_gradio_app(app, _gradio_app, path="/")
57
+
58
+
59
+ def main(host: str = "0.0.0.0", port: int = 8000):
60
+ """
61
+ Entry point for direct execution via uv run or python -m.
62
+
63
+ This function enables running the server without Docker:
64
+ uv run --project . server
65
+ uv run --project . server --port 8001
66
+ python -m workflow_arena.server.app
67
+
68
+ Args:
69
+ host: Host address to bind to (default: "0.0.0.0")
70
+ port: Port number to listen on (default: 8000)
71
+
72
+ For production deployments, consider using uvicorn directly with
73
+ multiple workers:
74
+ uvicorn workflow_arena.server.app:app --workers 4
75
+ """
76
+ import uvicorn
77
+
78
+ uvicorn.run(app, host=host, port=port)
79
+
80
+
81
+ if __name__ == "__main__":
82
+ import argparse
83
+
84
+ parser = argparse.ArgumentParser()
85
+ parser.add_argument("--port", type=int, default=8000)
86
+ args = parser.parse_args()
87
+ if args.port == 8000:
88
+ main()
89
+ else:
90
+ main(port=args.port)
server/requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ openenv-core[core]>=0.2.2
2
+ fastapi>=0.115.0
3
+ uvicorn>=0.24.0
4
+ gradio>=5.0.0
5
+ plotly>=5.24.0
6
+
server/ui.py ADDED
@@ -0,0 +1,1270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Interactive Gradio UI for WorkflowArena."""
8
+
9
+ from __future__ import annotations
10
+
11
+ import random
12
+ from types import SimpleNamespace
13
+ from typing import Any
14
+
15
+ import gradio as gr
16
+ import plotly.graph_objects as go
17
+
18
+ from workflow_arena.models import DifficultyPreset, TaskStatus, WorkflowActionType, WorkflowArenaAction
19
+ from workflow_arena.presets import get_preset_config
20
+ from workflow_arena.server.workflow_arena_environment import WorkflowArenaEnvironment
21
+
22
+
23
+ Session = dict[str, Any]
24
+
25
+ DETAIL_HEADERS = [
26
+ "Task",
27
+ "Priority",
28
+ "Duration",
29
+ "Deadline",
30
+ "Criticality",
31
+ "Slack",
32
+ "Deps",
33
+ "Downstream",
34
+ "Attempts",
35
+ "Start",
36
+ "End",
37
+ ]
38
+
39
+ PRESET_BRIEFS = {
40
+ DifficultyPreset.EASY.value: {
41
+ "label": "Warm-up Flow",
42
+ "summary": "Small DAG, softer deadlines, and fewer traps. Good for learning how dispatch and wait interact.",
43
+ "focus": "Keep workers busy, avoid empty waits, and build intuition for parallel batches.",
44
+ "mechanics": "No hard time budget and no failure events.",
45
+ },
46
+ DifficultyPreset.MEDIUM.value: {
47
+ "label": "Balanced Pressure",
48
+ "summary": "Tighter dependencies and more timing pressure. Scheduling mistakes start to compound.",
49
+ "focus": "Balance urgency, downstream unlocks, and worker utilization.",
50
+ "mechanics": "Adds a fixed time budget and terminal penalty for unfinished work.",
51
+ },
52
+ DifficultyPreset.HARD.value: {
53
+ "label": "Critical Path Sprint",
54
+ "summary": "Dense DAGs, tighter deadlines, and much less room for idle capacity.",
55
+ "focus": "Protect the critical path and use every free slot intentionally.",
56
+ "mechanics": "Adds a tighter time budget plus seeded worker outages and task retry failures.",
57
+ },
58
+ }
59
+
60
+ CSS = """
61
+ .gradio-container {
62
+ background:
63
+ radial-gradient(circle at top left, rgba(216, 116, 76, 0.14), transparent 28%),
64
+ radial-gradient(circle at top right, rgba(201, 157, 92, 0.10), transparent 24%),
65
+ linear-gradient(180deg, #fbf4ea 0%, #f4e7d6 100%);
66
+ color: #2d241c;
67
+ font-family: "IBM Plex Sans", "Avenir Next", "Segoe UI", sans-serif;
68
+ }
69
+ .wa-shell {max-width: 1380px; margin: 0 auto; padding: 10px 10px 30px;}
70
+ .wa-title {margin-bottom: 18px;}
71
+ .wa-title h1 {margin: 0; font-size: 2.6rem; line-height: 1; letter-spacing: -0.04em; color: #3e2618;}
72
+ .wa-title p {margin: 10px 0 0; max-width: 920px; font-size: 1rem; color: #745f50;}
73
+ .wa-hero {display: grid; grid-template-columns: 1.15fr 0.85fr; gap: 18px; margin-bottom: 18px;}
74
+ .wa-card {
75
+ background: rgba(255, 252, 247, 0.92);
76
+ border: 1px solid rgba(139, 110, 84, 0.12);
77
+ border-radius: 24px;
78
+ box-shadow: 0 18px 60px rgba(114, 84, 51, 0.10);
79
+ backdrop-filter: blur(16px);
80
+ }
81
+ .wa-control-card {padding: 20px;}
82
+ .wa-control-card h3,
83
+ .wa-panel h3,
84
+ .wa-playbook h3 {
85
+ margin: 0 0 8px;
86
+ font-size: 0.8rem;
87
+ letter-spacing: 0.12em;
88
+ text-transform: uppercase;
89
+ color: #9f5b33;
90
+ }
91
+ .wa-control-card p,
92
+ .wa-panel p,
93
+ .wa-playbook p {margin: 0; color: #715e50; line-height: 1.5;}
94
+ .wa-control-grid {display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 12px; align-items: end; margin-top: 14px;}
95
+ .wa-control-buttons {display: flex; gap: 10px; align-items: center; margin-top: 12px;}
96
+ .wa-inline-buttons {display: flex; gap: 10px; align-items: center; flex-wrap: wrap;}
97
+ .wa-compact-accordion {margin-top: 14px;}
98
+ .wa-compact-accordion .label-wrap span {font-size: 0.85rem;}
99
+ .wa-problem-box {
100
+ padding: 4px 2px 2px;
101
+ border-radius: 18px;
102
+ }
103
+ .wa-problem-box strong {color: #8d4f2d;}
104
+ .wa-problem-box p {margin: 0 0 8px; color: #6d594c;}
105
+ .wa-problem-box p:last-child {margin-bottom: 0;}
106
+ .wa-preset-card {padding: 20px; min-height: 100%;}
107
+ .wa-preset-card .eyebrow {font-size: 0.75rem; letter-spacing: 0.14em; text-transform: uppercase; color: #b16d3d;}
108
+ .wa-preset-card .name {margin-top: 8px; font-size: 1.7rem; font-weight: 700; letter-spacing: -0.03em; color: #3e2618;}
109
+ .wa-preset-card .summary {margin-top: 10px; color: #6e594a; line-height: 1.55;}
110
+ .wa-preset-meta {margin-top: 12px; color: #7b6554; line-height: 1.5;}
111
+ .wa-preset-focus {
112
+ margin-top: 14px;
113
+ padding: 12px 14px;
114
+ border-radius: 18px;
115
+ background: linear-gradient(135deg, rgba(222, 143, 93, 0.12), rgba(242, 210, 171, 0.28));
116
+ border: 1px solid rgba(174, 117, 72, 0.14);
117
+ }
118
+ .wa-topbar {display: grid; grid-template-columns: 1.2fr 1fr 1fr 1fr 1fr 1fr; gap: 12px; margin: 0 0 16px;}
119
+ .wa-stat {
120
+ background: linear-gradient(180deg, #fff8ef 0%, #f9efe2 100%);
121
+ border: 1px solid rgba(168, 130, 95, 0.16);
122
+ border-radius: 20px;
123
+ padding: 16px 18px;
124
+ }
125
+ .wa-stat .label {font-size: 0.7rem; letter-spacing: 0.12em; text-transform: uppercase; color: #a56c43;}
126
+ .wa-stat .value {margin-top: 6px; font-size: 1.65rem; font-weight: 700; color: #3e2618;}
127
+ .wa-stat .sub {margin-top: 4px; font-size: 0.82rem; color: #7a6657;}
128
+ .wa-banner {
129
+ border-radius: 24px;
130
+ padding: 18px 20px;
131
+ border: 1px solid rgba(177, 107, 70, 0.18);
132
+ background: linear-gradient(135deg, rgba(245, 186, 144, 0.35), rgba(255, 251, 245, 0.96));
133
+ color: #35271f;
134
+ margin-bottom: 16px;
135
+ }
136
+ .wa-banner.invalid {
137
+ background: linear-gradient(135deg, rgba(247, 202, 196, 0.92), rgba(255, 246, 244, 0.98));
138
+ border-color: rgba(180, 84, 69, 0.22);
139
+ }
140
+ .wa-banner.done {
141
+ background: linear-gradient(135deg, rgba(219, 229, 195, 0.9), rgba(255, 251, 244, 0.98));
142
+ border-color: rgba(112, 139, 90, 0.22);
143
+ }
144
+ .wa-banner-top {display: flex; justify-content: space-between; gap: 16px; align-items: flex-start;}
145
+ .wa-banner .status {
146
+ display: inline-block;
147
+ padding: 5px 10px;
148
+ border-radius: 999px;
149
+ font-size: 0.72rem;
150
+ font-weight: 700;
151
+ letter-spacing: 0.08em;
152
+ text-transform: uppercase;
153
+ background: rgba(164, 97, 59, 0.12);
154
+ }
155
+ .wa-banner .meta {margin-top: 8px; font-size: 0.9rem; color: #7f6654;}
156
+ .wa-banner .note {margin-top: 12px; font-size: 1rem; line-height: 1.5; color: #49372c;}
157
+ .wa-banner-grid {display: grid; grid-template-columns: repeat(4, minmax(0, 1fr)); gap: 10px; margin-top: 14px;}
158
+ .wa-banner-metric {
159
+ padding: 12px 14px;
160
+ border-radius: 18px;
161
+ background: rgba(255, 253, 248, 0.76);
162
+ border: 1px solid rgba(178, 132, 96, 0.12);
163
+ }
164
+ .wa-banner-metric span {display: block; font-size: 0.72rem; letter-spacing: 0.08em; text-transform: uppercase; color: #a16b45;}
165
+ .wa-banner-metric strong {display: block; margin-top: 6px; font-size: 1.1rem; color: #3d281c;}
166
+ .wa-progress {height: 10px; margin-top: 14px; border-radius: 999px; overflow: hidden; background: rgba(120, 83, 54, 0.10);}
167
+ .wa-progress-fill {height: 100%; background: linear-gradient(90deg, #e49157 0%, #f1c27a 100%);}
168
+ .wa-main {display: grid; grid-template-columns: 1.18fr 0.82fr; gap: 18px; align-items: start;}
169
+ .wa-left-stack,
170
+ .wa-right-stack {display: grid; gap: 18px;}
171
+ .wa-panel {padding: 18px 18px 16px;}
172
+ .wa-playbook {padding: 18px;}
173
+ .wa-playbook-header {display: flex; justify-content: space-between; gap: 12px; align-items: center; margin-bottom: 12px;}
174
+ .wa-playbook-title {font-size: 1.4rem; font-weight: 700; letter-spacing: -0.03em; color: #3e2618;}
175
+ .wa-chip-row {display: flex; flex-wrap: wrap; gap: 8px; margin-top: 12px;}
176
+ .wa-chip {
177
+ display: inline-flex;
178
+ align-items: center;
179
+ padding: 7px 10px;
180
+ border-radius: 999px;
181
+ background: rgba(172, 121, 80, 0.10);
182
+ color: #7b4f32;
183
+ font-size: 0.84rem;
184
+ font-weight: 600;
185
+ }
186
+ .wa-lane-header {display: flex; justify-content: space-between; gap: 12px; align-items: flex-start; margin-bottom: 8px;}
187
+ .wa-lane-title {font-size: 1.32rem; font-weight: 700; letter-spacing: -0.03em; color: #3e2618;}
188
+ .wa-lane-copy {font-size: 0.96rem; color: #78624f;}
189
+ .wa-hint {
190
+ margin-bottom: 14px;
191
+ padding: 12px 14px;
192
+ border-radius: 18px;
193
+ background: rgba(174, 126, 88, 0.08);
194
+ border: 1px solid rgba(174, 126, 88, 0.12);
195
+ color: #6d4a32;
196
+ }
197
+ .wa-card-grid {display: grid; grid-template-columns: repeat(auto-fit, minmax(235px, 1fr)); gap: 12px;}
198
+ .wa-task-card {
199
+ background: linear-gradient(180deg, #fffaf2 0%, #f7ecde 100%);
200
+ border: 1px solid rgba(175, 135, 100, 0.18);
201
+ border-radius: 22px;
202
+ padding: 14px;
203
+ color: #34261d;
204
+ }
205
+ .wa-task-card.running {background: linear-gradient(180deg, #f4ebe0 0%, #ecdcca 100%);}
206
+ .wa-task-card.recommended {outline: 2px solid rgba(226, 145, 87, 0.9); outline-offset: 2px;}
207
+ .wa-task-head {display: flex; justify-content: space-between; gap: 10px; align-items: flex-start; margin-bottom: 10px;}
208
+ .wa-task-name {font-size: 1.08rem; font-weight: 700; color: #3c271b;}
209
+ .wa-badge {
210
+ display: inline-flex;
211
+ align-items: center;
212
+ padding: 4px 8px;
213
+ border-radius: 999px;
214
+ background: rgba(170, 123, 84, 0.12);
215
+ color: #825336;
216
+ font-size: 0.68rem;
217
+ font-weight: 700;
218
+ letter-spacing: 0.08em;
219
+ text-transform: uppercase;
220
+ }
221
+ .wa-badge.urgent {background: rgba(208, 108, 97, 0.16); color: #8c3c35;}
222
+ .wa-badge.active {background: rgba(151, 179, 120, 0.18); color: #5f7142;}
223
+ .wa-badge.recommended {background: rgba(229, 166, 93, 0.20); color: #86501f;}
224
+ .wa-badge.retry {background: rgba(176, 141, 78, 0.18); color: #7a5a22;}
225
+ .wa-task-meta {display: flex; flex-wrap: wrap; gap: 8px; margin-bottom: 10px;}
226
+ .wa-task-meta span {
227
+ padding: 5px 8px;
228
+ border-radius: 999px;
229
+ background: rgba(179, 139, 104, 0.10);
230
+ font-size: 0.76rem;
231
+ color: #77553b;
232
+ }
233
+ .wa-metrics {display: grid; grid-template-columns: repeat(2, minmax(0, 1fr)); gap: 10px 12px;}
234
+ .wa-metric span {display: block; font-size: 0.68rem; letter-spacing: 0.08em; text-transform: uppercase; color: #a26c45;}
235
+ .wa-metric strong {display: block; margin-top: 4px; font-size: 0.98rem; color: #35261d;}
236
+ .wa-empty {
237
+ padding: 20px;
238
+ border-radius: 20px;
239
+ border: 1px dashed rgba(176, 133, 98, 0.26);
240
+ background: rgba(178, 143, 112, 0.06);
241
+ color: #7b6656;
242
+ text-align: center;
243
+ }
244
+ .wa-action-row {display: flex; flex-wrap: wrap; gap: 10px; margin-top: 14px;}
245
+ .wa-button-primary button {background: linear-gradient(135deg, #d97b4b, #c95f34) !important; color: #fff7f0 !important; border: none !important;}
246
+ .wa-button-secondary button {background: #8f5b3b !important; color: #fff8f2 !important; border: none !important;}
247
+ .wa-button-ghost button {background: rgba(180, 132, 96, 0.08) !important; color: #7a4d31 !important; border: 1px solid rgba(180, 132, 96, 0.16) !important;}
248
+ .wa-plot-wrap {padding: 10px 10px 2px;}
249
+ .wa-footer-stack {display: grid; gap: 18px; margin-top: 18px;}
250
+ .wa-accordion {border-radius: 20px !important; overflow: hidden;}
251
+ @media (max-width: 1080px) {
252
+ .wa-hero,
253
+ .wa-main {grid-template-columns: 1fr;}
254
+ .wa-topbar {grid-template-columns: repeat(2, minmax(0, 1fr));}
255
+ }
256
+ @media (max-width: 760px) {
257
+ .wa-control-grid {grid-template-columns: 1fr;}
258
+ .wa-banner-grid {grid-template-columns: repeat(2, minmax(0, 1fr));}
259
+ .wa-topbar {grid-template-columns: 1fr;}
260
+ }
261
+ """
262
+
263
+
264
def _blank_session() -> Session:
    """Create a fresh UI session: new environment, no observation, empty history."""
    return {
        "env": WorkflowArenaEnvironment(),
        "observation": None,
        "history": [],
    }
266
+
267
+
268
+ def _fmt_num(value: Any, digits: int = 3) -> str:
269
+ if value is None:
270
+ return "—"
271
+ if isinstance(value, float):
272
+ return f"{value:.{digits}f}"
273
+ return str(value)
274
+
275
+
276
def _preset_html(preset: str) -> str:
    """Render the preset-brief card for the given preset value.

    Falls back to the EASY brief when *preset* is not a known key, and
    derives the time-budget note from the preset's configured multiplier.
    """
    brief = PRESET_BRIEFS.get(preset, PRESET_BRIEFS[DifficultyPreset.EASY.value])
    preset_config = get_preset_config(DifficultyPreset(preset))
    # None means the preset imposes no time budget at all.
    budget_note = (
        "No fixed time budget."
        if preset_config.time_budget_multiplier is None
        else f"Time budget uses {preset_config.time_budget_multiplier:.2f}x the lower-bound makespan."
    )
    return (
        '<div class="wa-preset-card wa-card">'
        '<div class="eyebrow">Preset brief</div>'
        f'<div class="name">{brief["label"]}</div>'
        f'<div class="summary">{brief["summary"]}</div>'
        f'<div class="wa-preset-meta">{budget_note} {brief["mechanics"]}</div>'
        f'<div class="wa-preset-focus"><strong>What matters now:</strong> {brief["focus"]}</div>'
        "</div>"
    )
293
+
294
+
295
+ def _status_text(observation: Any) -> tuple[str, str]:
296
+ if observation.validation_error:
297
+ return "Invalid action", "bad"
298
+ if observation.done and observation.termination_reason:
299
+ return "Episode terminated", "bad"
300
+ if observation.done:
301
+ return "Workflow completed", "good"
302
+ if observation.free_workers == 0 and observation.running_tasks:
303
+ return "Wait required", ""
304
+ if observation.ready_tasks:
305
+ return "Ready to dispatch", ""
306
+ return "Waiting on completions", ""
307
+
308
+
309
+ def _recommended_task_ids(observation: Any) -> list[str]:
310
+ if observation is None or observation.done or observation.free_workers <= 0:
311
+ return []
312
+ ready_tasks = list(observation.ready_tasks)
313
+ if not ready_tasks:
314
+ return []
315
+ time_remaining = observation.time_remaining
316
+ ranked = sorted(
317
+ ready_tasks,
318
+ key=lambda task: (
319
+ time_remaining is not None and task.duration > time_remaining,
320
+ max(0, task.duration - time_remaining) if time_remaining is not None else 0,
321
+ task.slack if task.slack is not None else 1_000_000,
322
+ task.deadline if task.deadline is not None else 1_000_000,
323
+ -(task.criticality or 0.0),
324
+ -task.priority,
325
+ task.duration,
326
+ task.task_id,
327
+ ),
328
+ )
329
+ return [task.task_id for task in ranked[: observation.free_workers]]
330
+
331
+
332
+ def _dispatch_window(observation: Any) -> tuple[int, int, int]:
333
+ ready_count = len(observation.ready_tasks)
334
+ free_workers = max(0, observation.free_workers)
335
+ dispatchable_now = min(ready_count, free_workers)
336
+ overflow_ready = max(0, ready_count - free_workers)
337
+ return ready_count, dispatchable_now, overflow_ready
338
+
339
+
340
def _topbar_html(observation: Any) -> str:
    """Render the six-stat top bar (state, workers, progress, reward, time, score)."""
    completed = observation.progress.completed
    # Clamp the denominator so an empty DAG cannot divide by zero.
    total = max(1, observation.progress.total)
    # Prefer the top-level benchmark score; fall back to success metrics.
    score = observation.benchmark_score
    if score is None:
        score = observation.success_metrics.benchmark_score
    time_sub = (
        f"{observation.time_remaining} remaining"
        if observation.time_remaining is not None
        else "simulation clock"
    )
    # When an outage reduces usable workers below the total, label the
    # worker stat accordingly.
    worker_sub = (
        f"idle / usable of {observation.total_workers}"
        if getattr(observation, "effective_workers", observation.total_workers) != observation.total_workers
        else "free / total"
    )
    # (label, value, sub-caption) triples rendered left-to-right.
    cards = [
        ("State", _status_text(observation)[0], f"{observation.progress.ready} ready / {observation.progress.running} running"),
        (
            "Workers",
            f"{observation.free_workers}/{getattr(observation, 'effective_workers', observation.total_workers)}",
            worker_sub,
        ),
        ("Completed", f"{completed}/{total}", f"{round(100 * completed / total, 1)}% finished"),
        ("Reward", _fmt_num(observation.cumulative_reward, 3), "cumulative"),
        ("Time", observation.current_time, time_sub),
        ("Score", _fmt_num(score, 3), "terminal if done"),
    ]
    return '<div class="wa-topbar">' + "".join(
        f'<div class="wa-stat"><div class="label">{label}</div><div class="value">{value}</div><div class="sub">{sub}</div></div>'
        for label, value, sub in cards
    ) + "</div>"
372
+
373
+
374
def _banner_html(observation: Any) -> str:
    """Render the status banner: state, note, key metrics, and progress bar.

    The banner's CSS class shifts with episode state: ``invalid`` for a
    rejected action or abnormal termination, ``done`` on success.
    """
    completed = observation.progress.completed
    total = max(1, observation.progress.total)  # guard divide-by-zero on empty DAGs
    progress_pct = round(100 * completed / total, 1)
    status_text, status_kind = _status_text(observation)
    banner_class = "wa-banner"
    if status_kind == "bad":
        banner_class += " invalid"
    elif status_kind == "good":
        banner_class += " done"

    # Fold validation errors and failure-event summaries into the note.
    failure_note = _failure_summary(observation)
    note = observation.note or "No environment note."
    if observation.validation_error:
        note = f"{note} {observation.validation_error}"
    if failure_note:
        note = f"{note} {failure_note}"

    # NOTE(review): the original also computed a benchmark score here but
    # never rendered it; the dead computation has been removed.

    metric_cards = [
        ("Ready", observation.progress.ready),
        ("Running", observation.progress.running),
        ("Workers", f"{getattr(observation, 'effective_workers', observation.total_workers)}/{observation.total_workers}"),
        (
            "Time Left",
            observation.time_remaining if observation.time_remaining is not None else "—",
        ),
    ]

    return (
        f'<div class="{banner_class}">'
        '<div class="wa-banner-top">'
        '<div>'
        f'<span class="status">{status_text}</span>'
        f'<div class="meta">Preset: {observation.config.preset.value} • Seed: {observation.config.seed} • Workers: {observation.total_workers}</div>'
        f'<div class="note">{note}</div>'
        "</div>"
        f'<div class="meta">{completed}/{total} complete</div>'
        "</div>"
        '<div class="wa-banner-grid">'
        + "".join(
            f'<div class="wa-banner-metric"><span>{label}</span><strong>{value}</strong></div>'
            for label, value in metric_cards
        )
        + "</div>"
        f'<div class="wa-progress"><div class="wa-progress-fill" style="width:{progress_pct:.1f}%"></div></div>'
        "</div>"
    )
425
+
426
+
427
def _planner_html(observation: Any) -> str:
    """Render the decision-support card: a headline, advice body, and chips.

    The headline/body pair is chosen by episode state in priority order:
    finished > invalid action > all workers busy > dispatch recommendation
    > hold for completions.
    """
    recommended = _recommended_task_ids(observation)
    ready_count, dispatchable_now, overflow_ready = _dispatch_window(observation)
    if observation.done:
        title = "Episode finished"
        body = "Reset for another episode or inspect the final timeline and reward trace below."
    elif observation.validation_error:
        title = "Fix the last move"
        body = observation.validation_error
    elif observation.free_workers == 0 and observation.running_tasks:
        title = "Advance time"
        body = "All workers are occupied. Waiting is the only legal move until the next task completes."
    elif recommended:
        title = f"Dispatch {', '.join(recommended)}"
        body = (
            "These tasks minimize slack first, then prefer tighter deadlines, stronger criticality, and higher priority. "
            f"The recommendation is capped at `{dispatchable_now}` because only `{observation.free_workers}` worker"
            f"{'s are' if observation.free_workers != 1 else ' is'} free right now."
        )
    else:
        title = "Hold for completions"
        body = "No ready work is available. Wait until dependencies unlock new tasks."

    # Always-on chips, followed by situational ones (time budget, overflow,
    # next completion time, outages).
    chips = [
        f"free workers: {observation.free_workers}",
        f"usable workers: {getattr(observation, 'effective_workers', observation.total_workers)}",
        f"ready queue: {ready_count}",
        f"dispatchable now: {dispatchable_now}",
        f"last reward: {_fmt_num(observation.reward if hasattr(observation, 'reward') else 0.0, 3)}",
    ]
    if observation.time_remaining is not None:
        chips.append(f"time remaining: {observation.time_remaining}")
    if overflow_ready:
        chips.append(f"queued beyond capacity: {overflow_ready}")
    if observation.running_tasks:
        # Earliest end time among running tasks; falls back to the current
        # time for tasks with no recorded end_time.
        next_finish = min(task.end_time or observation.current_time for task in observation.running_tasks)
        chips.append(f"next completion: t={next_finish}")
    if observation.degraded_workers:
        chips.append(f"worker outage: -{observation.degraded_workers} usable")

    return (
        '<div class="wa-playbook wa-card">'
        '<div class="wa-playbook-header">'
        '<div>'
        '<h3>Decision support</h3>'
        f'<div class="wa-playbook-title">{title}</div>'
        "</div>"
        f'<div class="wa-chip">{_status_text(observation)[0]}</div>'
        "</div>"
        f'<p>{body}</p>'
        '<div class="wa-chip-row">'
        + "".join(f'<div class="wa-chip">{chip}</div>' for chip in chips)
        + "</div>"
        "</div>"
    )
482
+
483
+
484
def _capacity_hint(observation: Any) -> str:
    """Return a one-line hint describing the user's legal next move."""
    queue_size, cap_now, queued_overflow = _dispatch_window(observation)
    if observation.done:
        return "Episode finished. Review the schedule or reset to try another seed."
    if observation.validation_error:
        plural = 's' if observation.free_workers != 1 else ''
        return (
            f"Last action was rejected. Select at most {observation.free_workers} ready "
            f"task{plural}."
        )
    if not observation.free_workers and observation.running_tasks:
        return "All workers are busy. Use Wait to jump to the next completion."
    if not observation.ready_tasks:
        return "No ready tasks available right now. Wait until dependencies unlock more work."
    suffix = ""
    if queued_overflow:
        suffix = f" `{queued_overflow}` ready task(s) will stay queued."
    plural = 's' if queue_size != 1 else ''
    return (
        f"{queue_size} ready task{plural}. "
        f"You can dispatch up to {cap_now} right now.{suffix}"
    )
502
+
503
+
504
+ def _task_badges(task: Any, *, running: bool = False, recommended: bool = False) -> str:
505
+ badges: list[str] = []
506
+ if task.deadline is not None and task.slack is not None and task.slack <= 1:
507
+ badges.append('<span class="wa-badge urgent">Urgent</span>')
508
+ if getattr(task, "attempt_count", 0) > 0:
509
+ badges.append(
510
+ f'<span class="wa-badge retry">Retry {getattr(task, "attempt_count", 0) + 1}</span>'
511
+ )
512
+ if recommended:
513
+ badges.append('<span class="wa-badge recommended">Recommended</span>')
514
+ if running:
515
+ badges.append('<span class="wa-badge active">Running</span>')
516
+ if not badges:
517
+ badges.append('<span class="wa-badge">Ready</span>')
518
+ return "".join(badges)
519
+
520
+
521
def _task_card(task: Any, *, running: bool = False, recommended: bool = False) -> str:
    """Render one task as an HTML card with badges and a metric grid.

    The `running`/`recommended` flags add CSS classes and are forwarded to
    `_task_badges`; a running card shows its finish time, a ready card its
    start time.
    """
    # Missing optional fields are rendered as an em dash placeholder.
    deps = ", ".join(task.dependencies) if task.dependencies else "None"
    deadline = task.deadline if task.deadline is not None else "—"
    start = task.start_time if task.start_time is not None else "—"
    end = task.end_time if task.end_time is not None else "—"
    classes = ["wa-task-card"]
    if running:
        classes.append("running")
    if recommended:
        classes.append("recommended")
    return (
        f'<div class="{" ".join(classes)}">'
        '<div class="wa-task-head">'
        f'<div class="wa-task-name">{task.task_id}</div>'
        f'<div>{_task_badges(task, running=running, recommended=recommended)}</div>'
        "</div>"
        # attempt_count defaults to 0, so the display is 1-based ("attempts: 1").
        f'<div class="wa-task-meta"><span>deps: {deps}</span><span>downstream: {task.downstream_count}</span><span>attempts: {getattr(task, "attempt_count", 0) + 1}</span></div>'
        '<div class="wa-metrics">'
        f'<div class="wa-metric"><span>Deadline</span><strong>{deadline}</strong></div>'
        f'<div class="wa-metric"><span>Duration</span><strong>{task.duration}</strong></div>'
        f'<div class="wa-metric"><span>Priority</span><strong>{task.priority}</strong></div>'
        f'<div class="wa-metric"><span>Criticality</span><strong>{_fmt_num(task.criticality, 3)}</strong></div>'
        f'<div class="wa-metric"><span>Slack</span><strong>{_fmt_num(task.slack, 1)}</strong></div>'
        # The last metric cell is context sensitive: finish time while running,
        # start time otherwise.
        f'<div class="wa-metric"><span>{"Finish" if running else "Start"}</span><strong>{end if running else start}</strong></div>'
        "</div>"
        "</div>"
    )
548
+
549
+
550
+ def _cards_html(tasks: list[Any], *, running: bool = False, recommended_ids: set[str] | None = None) -> str:
551
+ if not tasks:
552
+ message = "No tasks in this lane yet." if running else "No ready tasks available."
553
+ return f'<div class="wa-empty">{message}</div>'
554
+ recommended_ids = recommended_ids or set()
555
+ return '<div class="wa-card-grid">' + "".join(
556
+ _task_card(task, running=running, recommended=task.task_id in recommended_ids)
557
+ for task in tasks
558
+ ) + "</div>"
559
+
560
+
561
def _timeline_figure(observation: Any) -> go.Figure:
    """Build the Plotly Gantt-style timeline for the current observation.

    Completed and running tasks are drawn as horizontal bars anchored at
    their start times; ready tasks appear as diamond markers at the current
    time. A dashed "Now" line marks the present tick.
    """
    fig = go.Figure()

    # Stable task_id ordering keeps the y-axis deterministic across renders.
    completed = sorted(observation.completed_tasks, key=lambda task: task.task_id)
    running = sorted(observation.running_tasks, key=lambda task: task.task_id)
    ready = sorted(observation.ready_tasks, key=lambda task: task.task_id)

    timeline_tasks = completed + running
    task_ids = [task.task_id for task in timeline_tasks] + [task.task_id for task in ready]

    if completed:
        fig.add_trace(
            go.Bar(
                # Bar length = duration; guard against missing/negative spans.
                x=[max(0, (task.end_time or 0) - (task.start_time or 0)) for task in completed],
                y=[task.task_id for task in completed],
                base=[task.start_time or 0 for task in completed],
                orientation="h",
                name="Completed",
                marker_color="#85c88a",
                hovertemplate=(
                    "<b>%{y}</b><br>Status: Completed<br>Start: %{base}<br>"
                    "Duration: %{x}<extra></extra>"
                ),
            )
        )

    if running:
        fig.add_trace(
            go.Bar(
                # Running tasks may lack start/end yet; fall back to "now".
                x=[
                    max(
                        0,
                        (task.end_time or observation.current_time) - (task.start_time or observation.current_time),
                    )
                    for task in running
                ],
                y=[task.task_id for task in running],
                base=[task.start_time or observation.current_time for task in running],
                orientation="h",
                name="Running",
                marker_color="#d88a5b",
                hovertemplate=(
                    "<b>%{y}</b><br>Status: Running<br>Start: %{base}<br>"
                    "Allocated span: %{x}<extra></extra>"
                ),
            )
        )

    if ready:
        fig.add_trace(
            go.Scatter(
                x=[observation.current_time] * len(ready),
                y=[task.task_id for task in ready],
                mode="markers",
                name="Ready",
                marker=dict(color="#9e6a43", size=11, symbol="diamond"),
                # customdata feeds the hovertemplate placeholders below.
                customdata=[[task.deadline, task.priority, task.duration] for task in ready],
                hovertemplate=(
                    "<b>%{y}</b><br>Status: Ready<br>Current time: %{x}<br>"
                    "Deadline: %{customdata[0]}<br>Priority: %{customdata[1]}<br>"
                    "Duration: %{customdata[2]}<extra></extra>"
                ),
            )
        )

    if not task_ids:
        # Empty episode: show a single placeholder row plus an annotation.
        task_ids = ["No tasks yet"]
        fig.add_annotation(
            text="Reset an episode to populate the workflow timeline.",
            x=0.5,
            y=0.5,
            xref="paper",
            yref="paper",
            showarrow=False,
            font=dict(color="#8c6f58", size=14),
        )

    # X-range extends one tick past the latest known finish (or "now").
    horizon_candidates = [observation.current_time + 1]
    horizon_candidates.extend(task.end_time or 0 for task in completed)
    horizon_candidates.extend(task.end_time or observation.current_time for task in running)
    x_max = max(horizon_candidates) + 1

    fig.add_vline(
        x=observation.current_time,
        line_width=2,
        line_dash="dash",
        line_color="#9f6b48",
        annotation_text="Now",
        annotation_position="top left",
    )

    fig.update_layout(
        barmode="overlay",
        # Scale the figure height with the number of rows, with a floor.
        height=max(280, 90 + 34 * len(task_ids)),
        margin=dict(l=10, r=10, t=44, b=18),
        paper_bgcolor="#ffffff",
        plot_bgcolor="#ffffff",
        font=dict(color="#4d382b", family="IBM Plex Sans, Arial, sans-serif"),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, x=0),
        title=dict(text="Workflow Timeline", x=0.02, font=dict(size=18)),
        xaxis=dict(
            title="Simulated Time",
            range=[0, x_max],
            gridcolor="#eadfce",
            zeroline=False,
            linecolor="#cfb79c",
            title_font=dict(color="#6c4a33"),
            tickfont=dict(color="#6c4a33"),
        ),
        yaxis=dict(
            title="Tasks",
            # Reversed so the first task renders at the top of the chart.
            categoryorder="array",
            categoryarray=list(reversed(task_ids)),
            gridcolor="#f2e8da",
            linecolor="#cfb79c",
            title_font=dict(color="#6c4a33"),
            tickfont=dict(color="#6c4a33"),
        ),
    )
    return fig
681
+
682
+
683
+ def _detail_rows(tasks: list[Any]) -> list[list[Any]]:
684
+ return [
685
+ [
686
+ task.task_id,
687
+ task.priority,
688
+ task.duration,
689
+ task.deadline if task.deadline is not None else "—",
690
+ _fmt_num(task.criticality, 3),
691
+ _fmt_num(task.slack, 1),
692
+ len(task.dependencies),
693
+ task.downstream_count,
694
+ getattr(task, "attempt_count", 0) + 1,
695
+ task.start_time if task.start_time is not None else "—",
696
+ task.end_time if task.end_time is not None else "—",
697
+ ]
698
+ for task in tasks
699
+ ]
700
+
701
+
702
+ def _failure_summary(observation: Any) -> str:
703
+ events = getattr(observation, "recent_failure_events", []) or []
704
+ if not events:
705
+ return ""
706
+ return " ".join(event.detail for event in events if getattr(event, "detail", ""))
707
+
708
+
709
def _selection_markdown(selected_task_ids: list[str], observation: Any) -> str:
    """Summarize the dispatch-builder state as markdown.

    With no selection it surfaces the recommended batch (when any); with a
    selection it reports priority sum, earliest/longest finish times, and a
    warning if the batch exceeds free-worker capacity.
    """
    if observation is None:
        return "No episode yet. Reset an episode to start building a dispatch batch."

    # Drop any selected ids that are no longer in the ready queue.
    task_map = {task.task_id: task for task in observation.ready_tasks}
    selected_tasks = [task_map[task_id] for task_id in selected_task_ids if task_id in task_map]
    capacity = max(0, observation.free_workers)
    ready_count, dispatchable_now, overflow_ready = _dispatch_window(observation)

    if not selected_tasks:
        recommended = _recommended_task_ids(observation)
        if not recommended:
            return (
                f"**Dispatch builder**\n\nReady queue: `{ready_count}`. "
                f"Dispatchable now: `{dispatchable_now}`."
            )
        overflow_suffix = f" `{overflow_ready}` ready task(s) stay queued after dispatch." if overflow_ready else ""
        return (
            f"**Dispatch builder**\n\nNo tasks selected yet. Ready queue: `{ready_count}`. "
            f"Recommended batch: `{', '.join(recommended)}`. Dispatch cap now: `{dispatchable_now}`.{overflow_suffix}"
        )

    total_priority = sum(task.priority for task in selected_tasks)
    # All selected tasks would start now, so finish = now + duration.
    shortest_finish = observation.current_time + min(task.duration for task in selected_tasks)
    longest_finish = observation.current_time + max(task.duration for task in selected_tasks)
    warnings: list[str] = []
    if len(selected_tasks) > capacity:
        warnings.append(f"Selection exceeds capacity by `{len(selected_tasks) - capacity}`.")

    warning_text = "\n\n" + " ".join(warnings) if warnings else ""
    return (
        f"**Dispatch builder**\n\nSelected `{len(selected_tasks)}` task(s) for `{capacity}` free slot(s). "
        f"Priority sum: `{total_priority}`. Earliest completion: `t={shortest_finish}`. "
        f"Longest in-flight span: `t={longest_finish}`.{warning_text}"
    )
744
+
745
+
746
+ def _reward_markdown(observation: Any) -> str:
747
+ breakdown = observation.last_reward_breakdown
748
+ rows = [
749
+ ("completion", breakdown.completion_reward),
750
+ ("utilization", breakdown.utilization_reward),
751
+ ("deadline", breakdown.deadline_reward),
752
+ ("criticality", breakdown.criticality_reward),
753
+ ("idle", breakdown.idle_penalty),
754
+ ("invalid", breakdown.invalid_action_penalty),
755
+ ("terminal", breakdown.terminal_makespan_score),
756
+ ("unfinished", breakdown.unfinished_task_penalty),
757
+ ]
758
+ lines = ["| Channel | Value |", "| --- | ---: |"]
759
+ lines.extend(f"| {label} | {value:.3f} |" for label, value in rows)
760
+ return "\n".join(lines)
761
+
762
+
763
+ def _history_markdown(history: list[dict[str, Any]]) -> str:
764
+ if not history:
765
+ return "No actions yet."
766
+ lines: list[str] = []
767
+ for item in history[-12:]:
768
+ reward = _fmt_num(item.get("reward"), 3)
769
+ suffix = f" • error: `{item['error']}`" if item.get("error") else ""
770
+ note = item.get("note") or ""
771
+ lines.append(f"**{item['label']}** at `t={item['time']}` • reward `{reward}`{suffix} \n{note}")
772
+ return "\n\n".join(reversed(lines))
773
+
774
+
775
def _blank_observation_view() -> Any:
    """Return a stand-in observation for rendering before any episode exists.

    Mirrors the attribute surface the rendering helpers read (counts, task
    lists, config) with zeroed/empty values, so the UI can draw an empty
    state without a real environment.
    """
    return SimpleNamespace(
        reward=0.0,
        progress=SimpleNamespace(completed=0, ready=0, running=0, blocked=0, total=1),
        benchmark_score=None,
        success_metrics=SimpleNamespace(benchmark_score=None, unfinished_task_count=0),
        free_workers=0,
        effective_workers=0,
        degraded_workers=0,
        total_workers=0,
        time_budget=None,
        time_remaining=None,
        cumulative_reward=0.0,
        current_time=0,
        done=False,
        termination_reason=None,
        validation_error=None,
        completed_tasks=[],
        ready_tasks=[],
        running_tasks=[],
        blocked_tasks=[],
        recent_failure_events=[],
        note="Reset an episode to start scheduling.",
        # Defaults to the EASY preset so preset-dependent widgets have a value.
        config=SimpleNamespace(preset=SimpleNamespace(value=DifficultyPreset.EASY.value), seed=0),
    )
800
+
801
+
802
def _empty_updates(session: Session):
    """Produce the full output tuple for the pre-episode empty state.

    The tuple's positional order must match the ``outputs`` list wired in
    ``create_gradio_app``; all action buttons are disabled.
    """
    empty_rows: list[list[Any]] = []
    blank = _blank_observation_view()
    return (
        session,
        _preset_html(DifficultyPreset.EASY.value),
        _topbar_html(blank),
        _banner_html(blank),
        _planner_html(blank),
        "No episode yet.",
        _selection_markdown([], blank),
        '<div class="wa-empty">Reset an episode to see ready tasks.</div>',
        gr.update(choices=[], value=[]),
        gr.update(interactive=False),
        gr.update(interactive=False),
        gr.update(interactive=False),
        gr.update(interactive=False),
        gr.update(interactive=False),
        '<div class="wa-empty">Running tasks will appear here after dispatch.</div>',
        _timeline_figure(blank),
        # Placeholder reward table with a single em-dash row.
        "| Channel | Value |\n| --- | ---: |\n| — | — |",
        "No actions yet.",
        empty_rows,
        empty_rows,
        empty_rows,
    )
828
+
829
+
830
def _render(session: Session):
    """Project the session's current observation into every Gradio output.

    Returns a tuple whose positional order must match the ``outputs`` list
    wired up in ``create_gradio_app``. Falls back to the empty-state tuple
    when no observation exists yet.
    """
    observation = session.get("observation")
    env = session.get("env")
    history = session.get("history", [])
    if observation is None:
        return _empty_updates(session)

    # Prefer the live environment's debug views for completed/blocked detail;
    # fall back to the observation when rendering without an environment.
    if env is None:
        completed_rows = _detail_rows(observation.completed_tasks)
        blocked_rows = _detail_rows(observation.blocked_tasks)
    else:
        completed_rows = _detail_rows(env.debug_task_views_for_status(TaskStatus.COMPLETED))
        blocked_rows = _detail_rows(env.debug_task_views_for_status(TaskStatus.BLOCKED))

    ready_choices = [task.task_id for task in observation.ready_tasks]
    recommended_ids = _recommended_task_ids(observation)
    can_recommend = bool(recommended_ids) and not observation.done
    can_wait = bool(observation.running_tasks) and not observation.done
    can_clear = bool(ready_choices) and not observation.done

    return (
        session,
        _preset_html(observation.config.preset.value),
        _topbar_html(observation),
        _banner_html(observation),
        _planner_html(observation),
        _capacity_hint(observation),
        _selection_markdown([], observation),
        _cards_html(observation.ready_tasks, running=False, recommended_ids=set(recommended_ids)),
        gr.update(choices=ready_choices, value=[]),
        gr.update(interactive=False),
        gr.update(interactive=can_wait),
        gr.update(interactive=can_recommend),
        gr.update(interactive=can_recommend),
        gr.update(interactive=can_clear),
        _cards_html(observation.running_tasks, running=True),
        _timeline_figure(observation),
        _reward_markdown(observation),
        _history_markdown(history),
        # Clear the dispatch-preview table. (Previously a dead walrus binding:
        # `empty_rows := _detail_rows([])` — the name was never used, and
        # _detail_rows([]) is just [].)
        [],
        completed_rows,
        blocked_rows,
    )
873
+
874
+
875
def _append_history(
    session: Session,
    label: str,
    observation: Any,
    *,
    reward: float | None = None,
    error: str | None = None,
) -> Session:
    """Record one action (label, time, reward, error, note) in the session's history.

    The history list is copied before appending, then written back, and the
    same session object is returned for chaining.
    """
    entry = {
        "label": label,
        "time": observation.current_time,
        "reward": reward,
        "error": error,
        "note": observation.note,
    }
    updated = list(session.get("history", []))
    updated.append(entry)
    session["history"] = updated
    return session
895
+
896
+
897
def _reset(preset: str, seed: float, worker_count: float, session: Session):
    """Start a fresh episode (reusing any existing env) and re-render all outputs."""
    env = session.get("env") or WorkflowArenaEnvironment()
    workers = int(worker_count)
    observation = env.reset(
        preset=preset,
        seed=int(seed),
        worker_count=workers,
    )
    fresh = {"env": env, "observation": observation, "history": []}
    label = f"reset • preset `{preset}` • workers `{workers}`"
    fresh = _append_history(fresh, label, observation, reward=0.0, error=None)
    return _render(fresh)
913
+
914
+
915
def _dispatch(selected_task_ids: list[str], session: Session):
    """Step the environment with a DISPATCH action for the selected tasks."""
    env = session.get("env")
    observation = session.get("observation")
    if env is None or observation is None:
        return _render(_blank_session())

    batch = ", ".join(selected_task_ids) if selected_task_ids else "(none)"
    next_observation = env.step(
        WorkflowArenaAction(
            action_type=WorkflowActionType.DISPATCH,
            task_ids=selected_task_ids,
        )
    )
    updated = {
        "env": env,
        "observation": next_observation,
        "history": session.get("history", []),
    }
    updated = _append_history(
        updated,
        "dispatch " + batch,
        next_observation,
        reward=next_observation.reward,
        error=next_observation.validation_error,
    )
    return _render(updated)
940
+
941
+
942
def _dispatch_recommended(session: Session):
    """One-click dispatch of the recommender's batch for the current observation."""
    observation = session.get("observation")
    if observation is not None:
        return _dispatch(_recommended_task_ids(observation), session)
    return _render(_blank_session())
947
+
948
+
949
def _wait(session: Session):
    """Step the environment with a WAIT action (advance to next completion)."""
    env = session.get("env")
    observation = session.get("observation")
    if env is None or observation is None:
        return _render(_blank_session())

    next_observation = env.step(
        WorkflowArenaAction(action_type=WorkflowActionType.WAIT, task_ids=[])
    )
    updated = {
        "env": env,
        "observation": next_observation,
        "history": session.get("history", []),
    }
    updated = _append_history(
        updated,
        "wait",
        next_observation,
        reward=next_observation.reward,
        error=next_observation.validation_error,
    )
    return _render(updated)
973
+
974
+
975
def _update_selection(selected_task_ids: list[str], session: Session):
    """Refresh the batch preview and dispatch-button state for a new selection."""
    observation = session.get("observation")
    if observation is None:
        return "No episode yet.", [], gr.update(interactive=False)

    ready_by_id = {task.task_id: task for task in observation.ready_tasks}
    chosen = [tid_task for tid_task in (ready_by_id.get(tid) for tid in selected_task_ids) if tid_task is not None]
    # Dispatch is legal only when every selected id is still ready, the batch
    # fits the free worker slots, and the episode is still live.
    all_still_ready = len(chosen) == len(selected_task_ids)
    fits_capacity = len(selected_task_ids) <= observation.free_workers
    can_dispatch = bool(chosen) and all_still_ready and fits_capacity and not observation.done
    return (
        _selection_markdown(selected_task_ids, observation),
        _detail_rows(chosen),
        gr.update(interactive=can_dispatch),
    )
993
+
994
+
995
def _select_recommended(session: Session):
    """Fill the checkbox group with the recommended batch and preview it."""
    observation = session.get("observation")
    if observation is None:
        return gr.update(value=[]), "No episode yet.", [], gr.update(interactive=False)
    picks = _recommended_task_ids(observation)
    by_id = {task.task_id: task for task in observation.ready_tasks}
    preview = [by_id[tid] for tid in picks if tid in by_id]
    enable = bool(picks) and not observation.done
    return (
        gr.update(value=picks),
        _selection_markdown(picks, observation),
        _detail_rows(preview),
        gr.update(interactive=enable),
    )
1008
+
1009
+
1010
def _clear_selection(session: Session):
    """Empty the checkbox group, reset the preview, and disable manual dispatch."""
    observation = session.get("observation")
    summary = _selection_markdown([], observation)
    return gr.update(value=[]), summary, [], gr.update(interactive=False)
1018
+
1019
+
1020
+ def _random_seed() -> int:
1021
+ return random.randint(0, 999_999)
1022
+
1023
+
1024
def _preset_controls_update(preset: str):
    """Sync the preset brief and default worker count when the preset changes."""
    config = get_preset_config(DifficultyPreset(preset))
    return _preset_html(preset), gr.update(value=config.worker_count)
1027
+
1028
+
1029
def create_gradio_app() -> gr.Blocks:
    """Assemble the WorkflowArena Blocks UI and wire all event handlers.

    Layout: episode controls (preset/seed/workers), status bars, a dispatch
    lane with a batch builder, a mission-control lane with the timeline plot,
    and footer accordions for completed/blocked task tables. All handlers
    return the same positionally ordered ``outputs`` tuple produced by
    ``_render`` / ``_empty_updates``.
    """
    with gr.Blocks(title="WorkflowArena") as demo:
        # Per-browser-session state: {"env", "observation", "history"}.
        session = gr.State(_blank_session())

        with gr.Column(elem_classes=["wa-shell"]):
            gr.HTML(f"<style>{CSS}</style>")
            gr.HTML(
                """
                <div class="wa-title">
                <h1>WorkflowArena</h1>
                <p>Run a workflow episode like a control room instead of a raw form. Reset a seeded DAG, inspect urgency and capacity, build a legal dispatch batch, then advance time when workers are saturated.</p>
                </div>
                """
            )

            # --- Episode controls -------------------------------------------------
            with gr.Row(elem_classes=["wa-hero"]):
                with gr.Column(elem_classes=["wa-control-card", "wa-card"]):
                    gr.HTML("<h3>Episode controls</h3><p>Change the preset, seed, or worker count, then reset to generate a new scheduling problem.</p>")
                    with gr.Accordion("Problem Framing", open=False, elem_classes=["wa-accordion", "wa-compact-accordion"]):
                        gr.HTML(
                            """
                            <div class="wa-problem-box">
                            <p><strong>What this problem is:</strong> You are scheduling a workflow where tasks depend on each other and workers are limited. At every step, the legal move is either to dispatch ready tasks to free workers or wait for the next completion event.</p>
                            <p><strong>What good play looks like:</strong> Finish urgent and high-value work on time, keep workers utilized, and avoid delaying the critical path. Higher difficulties add a time budget and, in hard mode, failure events that reduce usable capacity or force retries.</p>
                            </div>
                            """
                        )
                    with gr.Row(elem_classes=["wa-control-grid"]):
                        preset = gr.Dropdown(
                            label="Preset",
                            choices=[preset.value for preset in DifficultyPreset],
                            value=DifficultyPreset.EASY.value,
                            interactive=True,
                        )
                        seed = gr.Number(label="Seed", value=0, precision=0, minimum=0)
                        workers = gr.Slider(
                            minimum=1,
                            maximum=6,
                            step=1,
                            value=3,
                            label="Workers",
                            interactive=True,
                        )
                    with gr.Row(elem_classes=["wa-control-buttons", "wa-inline-buttons"]):
                        reset_button = gr.Button(
                            "Reset Episode",
                            variant="primary",
                            elem_classes=["wa-button-primary"],
                        )
                        random_seed_button = gr.Button(
                            "Random Seed",
                            variant="secondary",
                            elem_classes=["wa-button-ghost"],
                        )

                    preset_brief = gr.HTML(value=_preset_html(DifficultyPreset.EASY.value))

            # Status strips filled in by _render / _empty_updates.
            topbar = gr.HTML()
            banner = gr.HTML()
            planner = gr.HTML()

            # --- Main two-lane layout --------------------------------------------
            with gr.Row(elem_classes=["wa-main"]):
                with gr.Column(elem_classes=["wa-left-stack"]):
                    with gr.Column(elem_classes=["wa-panel", "wa-card"]):
                        gr.HTML(
                            """
                            <div class="wa-lane-header">
                            <div>
                            <div class="wa-lane-title">Dispatch Lane</div>
                            <div class="wa-lane-copy">Inspect ready tasks, build a batch, and send work only when the action is legal.</div>
                            </div>
                            </div>
                            """
                        )
                        selection_summary = gr.Markdown(elem_classes=["wa-hint"])
                        ready_cards = gr.HTML()
                        ready_selector = gr.CheckboxGroup(
                            label="Build dispatch batch",
                            info="Choose up to the number of currently free workers.",
                        )
                        with gr.Row(elem_classes=["wa-action-row"]):
                            select_recommended_button = gr.Button(
                                "Select Recommended",
                                variant="secondary",
                                interactive=False,
                                elem_classes=["wa-button-secondary"],
                            )
                            dispatch_recommended_button = gr.Button(
                                "Dispatch Recommended",
                                variant="primary",
                                interactive=False,
                                elem_classes=["wa-button-primary"],
                            )
                            clear_selection_button = gr.Button(
                                "Clear Selection",
                                variant="secondary",
                                interactive=False,
                                elem_classes=["wa-button-ghost"],
                            )
                            dispatch_button = gr.Button(
                                "Dispatch Selected",
                                variant="primary",
                                interactive=False,
                                elem_classes=["wa-button-primary"],
                            )
                            wait_button = gr.Button(
                                "Wait",
                                variant="secondary",
                                interactive=False,
                                elem_classes=["wa-button-secondary"],
                            )

                    with gr.Column(elem_classes=["wa-panel", "wa-card"]):
                        gr.HTML(
                            """
                            <div class="wa-lane-header">
                            <div>
                            <div class="wa-lane-title">Selected Batch</div>
                            <div class="wa-lane-copy">This preview updates as you pick tasks from the current ready queue.</div>
                            </div>
                            </div>
                            """
                        )
                        selected_table = gr.Dataframe(
                            headers=DETAIL_HEADERS,
                            value=[],
                            interactive=False,
                            wrap=True,
                            label="Dispatch preview",
                        )

                with gr.Column(elem_classes=["wa-right-stack"]):
                    with gr.Column(elem_classes=["wa-panel", "wa-card"]):
                        gr.HTML(
                            """
                            <div class="wa-lane-header">
                            <div>
                            <div class="wa-lane-title">Mission Control</div>
                            <div class="wa-lane-copy">Use the recommendation as a guide, not a rule. The reward trace below tells you whether the choice paid off.</div>
                            </div>
                            </div>
                            """
                        )
                        gr.Markdown(elem_classes=["wa-hint"], value="No episode yet.")
                        decision_hint = gr.Markdown(elem_classes=["wa-hint"])
                        running_cards = gr.HTML()

                    with gr.Column(elem_classes=["wa-plot-wrap", "wa-card"]):
                        timeline_plot = gr.Plot(label="Workflow Timeline")

                    with gr.Accordion("Reward Breakdown", open=False, elem_classes=["wa-accordion"]):
                        reward_markdown = gr.Markdown()

                    with gr.Accordion("Action History", open=False, elem_classes=["wa-accordion"]):
                        history_markdown = gr.Markdown()

            # --- Footer detail tables --------------------------------------------
            with gr.Row(elem_classes=["wa-footer-stack"]):
                with gr.Accordion("Completed Tasks", open=False, elem_classes=["wa-accordion"]):
                    completed_table = gr.Dataframe(
                        headers=DETAIL_HEADERS,
                        value=[],
                        interactive=False,
                        wrap=True,
                        label="Completed",
                    )

                with gr.Accordion("Blocked Tasks", open=False, elem_classes=["wa-accordion"]):
                    blocked_table = gr.Dataframe(
                        headers=DETAIL_HEADERS,
                        value=[],
                        interactive=False,
                        wrap=True,
                        label="Blocked",
                    )

        # Positional output order — must match the tuples returned by
        # _render and _empty_updates.
        outputs = [
            session,
            preset_brief,
            topbar,
            banner,
            planner,
            decision_hint,
            selection_summary,
            ready_cards,
            ready_selector,
            dispatch_button,
            wait_button,
            select_recommended_button,
            dispatch_recommended_button,
            clear_selection_button,
            running_cards,
            timeline_plot,
            reward_markdown,
            history_markdown,
            selected_table,
            completed_table,
            blocked_table,
        ]

        # --- Event wiring --------------------------------------------------------
        random_seed_button.click(_random_seed, outputs=[seed])
        preset.change(_preset_controls_update, inputs=[preset], outputs=[preset_brief, workers])

        reset_button.click(
            _reset,
            inputs=[preset, seed, workers, session],
            outputs=outputs,
        )
        dispatch_button.click(
            _dispatch,
            inputs=[ready_selector, session],
            outputs=outputs,
        )
        dispatch_recommended_button.click(
            _dispatch_recommended,
            inputs=[session],
            outputs=outputs,
        )
        wait_button.click(
            _wait,
            inputs=[session],
            outputs=outputs,
        )

        ready_selector.change(
            _update_selection,
            inputs=[ready_selector, session],
            outputs=[selection_summary, selected_table, dispatch_button],
        )
        select_recommended_button.click(
            _select_recommended,
            inputs=[session],
            outputs=[ready_selector, selection_summary, selected_table, dispatch_button],
        )
        clear_selection_button.click(
            _clear_selection,
            inputs=[session],
            outputs=[ready_selector, selection_summary, selected_table, dispatch_button],
        )

        # Render the empty state on initial page load.
        demo.load(lambda: _empty_updates(_blank_session()), outputs=outputs)

    return demo
server/workflow_arena_environment.py ADDED
@@ -0,0 +1,873 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """WorkflowArena event-driven workflow orchestration environment."""
8
+
9
+ from __future__ import annotations
10
+
11
+ import math
12
+ import random
13
+ from typing import Any
14
+ from uuid import uuid4
15
+
16
+ from openenv.core.env_server.interfaces import Environment
17
+ from openenv.core.env_server.types import State
18
+
19
+ from workflow_arena.generator import generate_episode
20
+ from workflow_arena.models import (
21
+ DifficultyPreset,
22
+ EpisodeConfig,
23
+ FailureEventType,
24
+ ProgressSummary,
25
+ RewardBreakdown,
26
+ SuccessMetrics,
27
+ TaskStatus,
28
+ WorkflowArenaAction,
29
+ WorkflowArenaObservation,
30
+ WorkflowEnvStateSnapshot,
31
+ WorkflowEpisodeSpec,
32
+ WorkflowFailureEvent,
33
+ WorkflowTaskSpec,
34
+ WorkflowTaskView,
35
+ WorkflowActionType,
36
+ )
37
+ from workflow_arena.presets import get_preset_config
38
+
39
+
40
class WorkflowArenaEnvironment(Environment):
    """Resource-constrained workflow scheduler with event-driven semantics."""

    # Each client session may hold an independent episode of this environment.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True
    # Step-limit inputs — presumably combined elsewhere as
    # max(FLOOR, MULTIPLIER * episode size); confirm against reset logic.
    STEP_LIMIT_FLOOR: int = 32
    STEP_LIMIT_MULTIPLIER: int = 8
    # Reward-shaping constants; negative values are penalties.
    INVALID_ACTION_PENALTY: float = -0.1
    OVERCAPACITY_INVALID_ACTION_PENALTY: float = -0.25
    AVOIDABLE_WAIT_PENALTY_PER_SLOT: float = -0.08
    UNFINISHED_PRIORITY_PENALTY: float = -0.02
    OVERDUE_PRIORITY_PENALTY_PER_TICK: float = -0.005
    # Cap on failure events surfaced per observation.
    MAX_RECENT_FAILURE_EVENTS: int = 6
52
+
53
    def __init__(self):
        """Initialize an unreset environment; ``reset`` must run before stepping."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._cumulative_reward = 0.0
        self._max_episode_steps = self.STEP_LIMIT_FLOOR
        # Placeholder config — presumably replaced by reset(); confirm there.
        self._config = EpisodeConfig(
            preset=DifficultyPreset.EASY,
            seed=0,
            worker_count=2,
        )
        # Both stay None until an episode is generated; _require_episode guards this.
        self._episode_spec: WorkflowEpisodeSpec | None = None
        self._env_state: WorkflowEnvStateSnapshot | None = None
        self._event_rng = random.Random(0)
65
+
66
+ def _require_episode(self) -> tuple[WorkflowEpisodeSpec, WorkflowEnvStateSnapshot]:
67
+ if self._episode_spec is None or self._env_state is None:
68
+ raise RuntimeError("Environment must be reset before use.")
69
+ return self._episode_spec, self._env_state
70
+
71
+ def _preset_config(self):
72
+ episode, _ = self._require_episode()
73
+ return episode.preset_config
74
+
75
+ def _task_map(self) -> dict[str, WorkflowTaskSpec]:
76
+ episode, _ = self._require_episode()
77
+ return {task.task_id: task for task in episode.tasks}
78
+
79
+ def _effective_worker_capacity(
80
+ self, env_state: WorkflowEnvStateSnapshot | None = None
81
+ ) -> int:
82
+ if env_state is None:
83
+ _, env_state = self._require_episode()
84
+ return max(0, self._config.worker_count - env_state.degraded_workers)
85
+
86
+ def _time_remaining(
87
+ self, env_state: WorkflowEnvStateSnapshot | None = None
88
+ ) -> int | None:
89
+ if env_state is None:
90
+ _, env_state = self._require_episode()
91
+ if env_state.time_budget is None:
92
+ return None
93
+ return max(0, env_state.time_budget - env_state.current_time)
94
+
95
+ def _terminal_score(self) -> float:
96
+ episode, env_state = self._require_episode()
97
+ if env_state.current_time <= 0:
98
+ return 0.0
99
+ lower_bound = self._lower_bound_makespan(episode)
100
+ score = lower_bound / max(lower_bound, env_state.current_time)
101
+ return round(score, 4)
102
+
103
+ def _benchmark_score(self) -> float:
104
+ makespan_score, deadline_score, utilization_score = self._grade_components(
105
+ include_terminal_makespan=True
106
+ )
107
+ return round(
108
+ (0.5 * makespan_score) + (0.3 * deadline_score) + (0.2 * utilization_score),
109
+ 4,
110
+ )
111
+
112
+ def _grade_components(
113
+ self, *, include_terminal_makespan: bool = False
114
+ ) -> tuple[float, float, float]:
115
+ episode, env_state = self._require_episode()
116
+ utilization = (
117
+ env_state.cumulative_busy_time
118
+ / (env_state.current_time * self._config.worker_count)
119
+ if env_state.current_time > 0
120
+ else 0.0
121
+ )
122
+ total_priority = sum(task.priority for task in episode.tasks) or 1
123
+ on_time_priority = 0
124
+ for task in episode.tasks:
125
+ end_time = env_state.task_end_times.get(task.task_id)
126
+ if end_time is None:
127
+ continue
128
+ if task.deadline is None or end_time <= task.deadline:
129
+ on_time_priority += task.priority
130
+ deadline_score = round(on_time_priority / total_priority, 4)
131
+ utilization_score = round(utilization, 4)
132
+ makespan_score = self._terminal_score() if include_terminal_makespan else 0.0
133
+ return makespan_score, deadline_score, utilization_score
134
+
135
+ def _unfinished_task_penalty(self, current_time: int) -> float:
136
+ episode, env_state = self._require_episode()
137
+ penalty = 0.0
138
+ for task in episode.tasks:
139
+ if env_state.task_statuses[task.task_id] == TaskStatus.COMPLETED:
140
+ continue
141
+ penalty += self.UNFINISHED_PRIORITY_PENALTY * task.priority
142
+ if task.deadline is not None and current_time > task.deadline:
143
+ penalty += (
144
+ self.OVERDUE_PRIORITY_PENALTY_PER_TICK
145
+ * task.priority
146
+ * (current_time - task.deadline)
147
+ )
148
+ return round(penalty, 4)
149
+
150
+ def _success_metrics(
151
+ self, *, benchmark_score_override: float | None = None
152
+ ) -> SuccessMetrics:
153
+ episode, env_state = self._require_episode()
154
+ unfinished_task_count = sum(
155
+ 1
156
+ for task in episode.tasks
157
+ if env_state.task_statuses[task.task_id] != TaskStatus.COMPLETED
158
+ )
159
+ deadline_miss_count = sum(
160
+ 1
161
+ for task in episode.tasks
162
+ if env_state.task_statuses[task.task_id] == TaskStatus.COMPLETED
163
+ and task.deadline is not None
164
+ and env_state.task_end_times.get(task.task_id, 0) > task.deadline
165
+ )
166
+ _, deadline_score, utilization_score = self._grade_components(
167
+ include_terminal_makespan=False
168
+ )
169
+ all_done = unfinished_task_count == 0
170
+ return SuccessMetrics(
171
+ makespan=env_state.current_time if all_done else None,
172
+ worker_utilization=utilization_score,
173
+ deadline_miss_count=deadline_miss_count,
174
+ unfinished_task_count=unfinished_task_count,
175
+ weighted_priority_completion=deadline_score,
176
+ benchmark_score=benchmark_score_override,
177
+ )
178
+
179
+ def _task_view(
180
+ self,
181
+ task: WorkflowTaskSpec,
182
+ status: TaskStatus,
183
+ *,
184
+ include_planner_hints: bool = True,
185
+ ) -> WorkflowTaskView:
186
+ _, env_state = self._require_episode()
187
+ return WorkflowTaskView(
188
+ task_id=task.task_id,
189
+ status=status,
190
+ duration=task.duration,
191
+ priority=task.priority,
192
+ dependencies=task.dependencies,
193
+ deadline=task.deadline,
194
+ criticality=task.criticality if include_planner_hints else None,
195
+ slack=float(task.slack) if include_planner_hints else None,
196
+ downstream_count=task.downstream_count if include_planner_hints else 0,
197
+ start_time=env_state.task_start_times.get(task.task_id),
198
+ end_time=(
199
+ env_state.task_end_times.get(task.task_id)
200
+ or env_state.task_assigned_finish_times.get(task.task_id)
201
+ ),
202
+ attempt_count=env_state.task_attempt_counts.get(task.task_id, 0),
203
+ )
204
+
205
+ def _task_views_for_status(self, status: TaskStatus) -> list[WorkflowTaskView]:
206
+ episode, env_state = self._require_episode()
207
+ return [
208
+ self._task_view(task, status, include_planner_hints=True)
209
+ for task in episode.tasks
210
+ if env_state.task_statuses[task.task_id] == status
211
+ ]
212
+
213
+ def debug_task_views_for_status(self, status: TaskStatus) -> list[WorkflowTaskView]:
214
+ return self._task_views_for_status(status)
215
+
216
+ def _set_recent_failure_events(
217
+ self,
218
+ env_state: WorkflowEnvStateSnapshot,
219
+ events: list[WorkflowFailureEvent],
220
+ ) -> None:
221
+ env_state.recent_failure_events = events[-self.MAX_RECENT_FAILURE_EVENTS :]
222
+
223
+ def _maybe_end_worker_outage(
224
+ self,
225
+ env_state: WorkflowEnvStateSnapshot,
226
+ events: list[WorkflowFailureEvent],
227
+ ) -> None:
228
+ if (
229
+ env_state.active_worker_outage_until is not None
230
+ and env_state.current_time >= env_state.active_worker_outage_until
231
+ ):
232
+ events.append(
233
+ WorkflowFailureEvent(
234
+ event_type=FailureEventType.WORKER_OUTAGE_END,
235
+ time=env_state.current_time,
236
+ worker_delta=1,
237
+ detail="Worker capacity restored.",
238
+ )
239
+ )
240
+ env_state.active_worker_outage_until = None
241
+ env_state.degraded_workers = 0
242
+
243
+ def _maybe_start_worker_outage(
244
+ self,
245
+ env_state: WorkflowEnvStateSnapshot,
246
+ events: list[WorkflowFailureEvent],
247
+ ) -> None:
248
+ preset_config = self._preset_config()
249
+ if self._config.preset != DifficultyPreset.HARD:
250
+ return
251
+ if env_state.active_worker_outage_until is not None:
252
+ return
253
+ if preset_config.worker_outage_rate <= 0.0:
254
+ return
255
+ if self._event_rng.random() >= preset_config.worker_outage_rate:
256
+ return
257
+
258
+ duration = self._event_rng.randint(
259
+ preset_config.worker_outage_duration_min,
260
+ preset_config.worker_outage_duration_max,
261
+ )
262
+ if duration <= 0:
263
+ return
264
+
265
+ env_state.degraded_workers = min(1, self._config.worker_count)
266
+ env_state.active_worker_outage_until = env_state.current_time + duration
267
+ events.append(
268
+ WorkflowFailureEvent(
269
+ event_type=FailureEventType.WORKER_OUTAGE_START,
270
+ time=env_state.current_time,
271
+ worker_delta=-env_state.degraded_workers,
272
+ duration=duration,
273
+ detail=f"Worker outage active until t={env_state.active_worker_outage_until}.",
274
+ )
275
+ )
276
+
277
+ def _should_retry_fail(self, task_id: str) -> bool:
278
+ preset_config = self._preset_config()
279
+ _, env_state = self._require_episode()
280
+ if self._config.preset != DifficultyPreset.HARD:
281
+ return False
282
+ if preset_config.task_retry_failure_rate <= 0.0:
283
+ return False
284
+ if env_state.task_attempt_counts.get(task_id, 0) >= preset_config.max_task_retries:
285
+ return False
286
+ return self._event_rng.random() < preset_config.task_retry_failure_rate
287
+
288
+ def _dispatch_potential(
289
+ self,
290
+ env_state: WorkflowEnvStateSnapshot,
291
+ task_map: dict[str, WorkflowTaskSpec],
292
+ ) -> tuple[float, float]:
293
+ if not env_state.running_task_ids:
294
+ return 0.0, 0.0
295
+
296
+ episode, _ = self._require_episode()
297
+ max_slack = max((task.slack for task in episode.tasks), default=0)
298
+ utilization_component = 0.06 * (
299
+ len(env_state.running_task_ids) / max(1, self._config.worker_count)
300
+ )
301
+ criticality_component = 0.0
302
+ for task_id in env_state.running_task_ids:
303
+ task = task_map[task_id]
304
+ slack_urgency = 1.0 if max_slack <= 0 else 1.0 - (task.slack / max_slack)
305
+ criticality_component += (0.6 * task.criticality) + (0.4 * slack_urgency)
306
+ criticality_component = 0.04 * (
307
+ criticality_component / max(1, self._config.worker_count)
308
+ )
309
+ return round(utilization_component, 4), round(criticality_component, 4)
310
+
311
+ def _base_observation(
312
+ self,
313
+ *,
314
+ reward: float,
315
+ breakdown: RewardBreakdown,
316
+ note: str,
317
+ done: bool,
318
+ benchmark_score_override: float | None = None,
319
+ ) -> WorkflowArenaObservation:
320
+ episode, env_state = self._require_episode()
321
+ ready_tasks = self._task_views_for_status(TaskStatus.READY)
322
+ running_tasks = self._task_views_for_status(TaskStatus.RUNNING)
323
+ completed_tasks = self._task_views_for_status(TaskStatus.COMPLETED)
324
+ blocked_tasks = self._task_views_for_status(TaskStatus.BLOCKED)
325
+ effective_workers = self._effective_worker_capacity(env_state)
326
+ return WorkflowArenaObservation(
327
+ done=done,
328
+ reward=reward,
329
+ config=self._config,
330
+ current_time=env_state.current_time,
331
+ total_workers=self._config.worker_count,
332
+ effective_workers=effective_workers,
333
+ degraded_workers=env_state.degraded_workers,
334
+ free_workers=max(0, effective_workers - len(running_tasks)),
335
+ time_budget=env_state.time_budget,
336
+ time_remaining=self._time_remaining(env_state),
337
+ progress=ProgressSummary(
338
+ total=len(episode.tasks),
339
+ blocked=len(blocked_tasks),
340
+ ready=len(ready_tasks),
341
+ running=len(running_tasks),
342
+ completed=len(completed_tasks),
343
+ ),
344
+ ready_tasks=ready_tasks,
345
+ running_tasks=running_tasks,
346
+ completed_tasks=completed_tasks,
347
+ blocked_tasks=blocked_tasks,
348
+ last_reward_breakdown=breakdown,
349
+ cumulative_reward=self._cumulative_reward,
350
+ success_metrics=self._success_metrics(
351
+ benchmark_score_override=benchmark_score_override
352
+ ),
353
+ note=note,
354
+ benchmark_score=benchmark_score_override,
355
+ recent_failure_events=env_state.recent_failure_events,
356
+ metadata={
357
+ "phase": "simulation_active",
358
+ "note": note,
359
+ "effective_workers": effective_workers,
360
+ "degraded_workers": env_state.degraded_workers,
361
+ "time_budget": env_state.time_budget,
362
+ "time_remaining": self._time_remaining(env_state),
363
+ "recent_failure_events": [
364
+ event.model_dump(mode="json") for event in env_state.recent_failure_events
365
+ ],
366
+ "episode_loop": [
367
+ "reset generates a seeded workflow DAG episode",
368
+ "dispatch(task_ids=[...]) starts ready tasks if workers are free",
369
+ "wait() advances simulated time to the next completion event",
370
+ "medium and hard episodes may end at a fixed time budget",
371
+ "hard mode may trigger outages and retry failures",
372
+ ],
373
+ },
374
+ )
375
+
376
+ def _lower_bound_makespan(self, episode: WorkflowEpisodeSpec) -> int:
377
+ total_work = sum(task.duration for task in episode.tasks)
378
+ work_bound = (total_work + self._config.worker_count - 1) // self._config.worker_count
379
+ path_bound = max(task.critical_path_length for task in episode.tasks)
380
+ return max(1, work_bound, path_bound)
381
+
382
+ def _termination_breakdown(
383
+ self,
384
+ *,
385
+ invalid_penalty: float = 0.0,
386
+ idle_penalty: float = 0.0,
387
+ terminal_makespan_score: float = 0.0,
388
+ unfinished_task_penalty: float = 0.0,
389
+ ) -> RewardBreakdown:
390
+ return RewardBreakdown(
391
+ invalid_action_penalty=round(invalid_penalty, 4),
392
+ idle_penalty=round(idle_penalty, 4),
393
+ terminal_makespan_score=round(terminal_makespan_score, 4),
394
+ unfinished_task_penalty=round(unfinished_task_penalty, 4),
395
+ )
396
+
397
+ def _terminate_episode(
398
+ self,
399
+ *,
400
+ note: str,
401
+ breakdown: RewardBreakdown,
402
+ reward: float,
403
+ reason: str,
404
+ benchmark_score: float | None = None,
405
+ ) -> WorkflowArenaObservation:
406
+ if benchmark_score is None:
407
+ benchmark_score = self._benchmark_score()
408
+ self._cumulative_reward += reward
409
+ observation = self._base_observation(
410
+ reward=reward,
411
+ breakdown=breakdown,
412
+ note=note,
413
+ done=True,
414
+ benchmark_score_override=benchmark_score,
415
+ )
416
+ observation.termination_reason = reason
417
+ observation.benchmark_score = benchmark_score
418
+ observation.metadata["termination_reason"] = reason
419
+ observation.metadata["benchmark_score"] = benchmark_score
420
+ return observation
421
+
422
+ def _step_limit_reached(self) -> bool:
423
+ return self._state.step_count >= self._max_episode_steps
424
+
425
+ def _maybe_terminate_for_limits(self) -> WorkflowArenaObservation | None:
426
+ if not self._step_limit_reached():
427
+ return None
428
+ _, env_state = self._require_episode()
429
+ unfinished_penalty = self._unfinished_task_penalty(env_state.current_time)
430
+ terminal_score = self._terminal_score()
431
+ breakdown = self._termination_breakdown(
432
+ terminal_makespan_score=terminal_score,
433
+ unfinished_task_penalty=unfinished_penalty,
434
+ )
435
+ reward = round(-1.0 + unfinished_penalty + terminal_score, 4)
436
+ return self._terminate_episode(
437
+ note="Episode terminated after hitting the safety step limit.",
438
+ breakdown=breakdown,
439
+ reward=reward,
440
+ reason="step_limit",
441
+ )
442
+
443
+ def _apply_invalid(
444
+ self,
445
+ message: str,
446
+ *,
447
+ penalty: float | None = None,
448
+ ) -> WorkflowArenaObservation:
449
+ _, env_state = self._require_episode()
450
+ applied_penalty = (
451
+ self.INVALID_ACTION_PENALTY if penalty is None else float(penalty)
452
+ )
453
+ breakdown = RewardBreakdown(invalid_action_penalty=round(applied_penalty, 4))
454
+ self._cumulative_reward += breakdown.invalid_action_penalty
455
+ self._set_recent_failure_events(env_state, [])
456
+ observation = self._base_observation(
457
+ reward=breakdown.invalid_action_penalty,
458
+ breakdown=breakdown,
459
+ note="Invalid action.",
460
+ done=False,
461
+ )
462
+ observation.validation_error = message
463
+ observation.metadata["validation_error"] = message
464
+ return observation
465
+
466
+ def _transition_unlocks(self, completed_task_ids: list[str]) -> list[str]:
467
+ episode, env_state = self._require_episode()
468
+ task_map = {task.task_id: task for task in episode.tasks}
469
+ unlocked: list[str] = []
470
+ for task_id in completed_task_ids:
471
+ for dependent_id in task_map[task_id].dependents:
472
+ env_state.task_remaining_dependencies[dependent_id] -= 1
473
+ if env_state.task_remaining_dependencies[dependent_id] == 0:
474
+ env_state.task_statuses[dependent_id] = TaskStatus.READY
475
+ if dependent_id not in env_state.ready_task_ids:
476
+ env_state.ready_task_ids.append(dependent_id)
477
+ if dependent_id in env_state.blocked_task_ids:
478
+ env_state.blocked_task_ids.remove(dependent_id)
479
+ unlocked.append(dependent_id)
480
+ env_state.ready_task_ids.sort()
481
+ env_state.blocked_task_ids.sort()
482
+ return unlocked
483
+
484
+ def reset(
485
+ self,
486
+ seed: int | None = None,
487
+ episode_id: str | None = None,
488
+ **kwargs: Any,
489
+ ) -> WorkflowArenaObservation:
490
+ """Generate a seeded workflow DAG episode."""
491
+
492
+ preset_raw = kwargs.pop("preset", DifficultyPreset.EASY)
493
+ worker_count_raw = kwargs.pop("worker_count", None)
494
+ del kwargs
495
+ preset = (
496
+ preset_raw
497
+ if isinstance(preset_raw, DifficultyPreset)
498
+ else DifficultyPreset(str(preset_raw))
499
+ )
500
+ preset_config = get_preset_config(preset)
501
+ chosen_seed = 0 if seed is None else seed
502
+ chosen_worker_count = (
503
+ preset_config.worker_count
504
+ if worker_count_raw is None
505
+ else int(worker_count_raw)
506
+ )
507
+ chosen_episode_id = str(uuid4()) if episode_id is None else episode_id
508
+ self._state = State(episode_id=chosen_episode_id, step_count=0)
509
+ self._cumulative_reward = 0.0
510
+ self._config = EpisodeConfig(
511
+ preset=preset,
512
+ seed=chosen_seed,
513
+ worker_count=chosen_worker_count,
514
+ )
515
+ self._event_rng = random.Random(
516
+ (chosen_seed + 1) * 1009
517
+ + (chosen_worker_count * 131)
518
+ + (list(DifficultyPreset).index(preset) + 1)
519
+ )
520
+ self._episode_spec, self._env_state = generate_episode(self._config)
521
+ self._max_episode_steps = max(
522
+ self.STEP_LIMIT_FLOOR,
523
+ len(self._episode_spec.tasks) * self.STEP_LIMIT_MULTIPLIER,
524
+ )
525
+ self._env_state.episode_id = chosen_episode_id
526
+ lower_bound = self._lower_bound_makespan(self._episode_spec)
527
+ if preset_config.time_budget_multiplier is not None:
528
+ self._env_state.time_budget = int(
529
+ math.ceil(lower_bound * preset_config.time_budget_multiplier)
530
+ )
531
+ self._set_recent_failure_events(self._env_state, [])
532
+
533
+ note = "Workflow episode generated. Dispatch ready tasks or wait for completions."
534
+ if self._env_state.time_budget is not None:
535
+ note = (
536
+ f"Workflow episode generated. Finish as much as possible before "
537
+ f"t={self._env_state.time_budget}."
538
+ )
539
+ if preset == DifficultyPreset.HARD:
540
+ note += " Hard mode may trigger worker outages and retry failures."
541
+ return self._base_observation(
542
+ reward=0.0,
543
+ breakdown=RewardBreakdown(),
544
+ note=note,
545
+ done=False,
546
+ )
547
+
548
+ def _wait_note(
549
+ self,
550
+ *,
551
+ completed_now: list[str],
552
+ failed_now: list[str],
553
+ unlocked: list[str],
554
+ recent_events: list[WorkflowFailureEvent],
555
+ time_budget_hit: bool = False,
556
+ ) -> str:
557
+ chunks: list[str] = []
558
+ if time_budget_hit:
559
+ chunks.append("Time budget exhausted before the next completion event.")
560
+ elif completed_now:
561
+ chunks.append(f"Completed: {', '.join(completed_now)}.")
562
+ else:
563
+ chunks.append("Advanced to next completion event.")
564
+ if failed_now:
565
+ chunks.append(f"Retry required: {', '.join(failed_now)}.")
566
+ if unlocked:
567
+ chunks.append(f"Unlocked: {', '.join(unlocked)}.")
568
+ for event in recent_events:
569
+ if event.event_type == FailureEventType.WORKER_OUTAGE_START:
570
+ chunks.append(event.detail)
571
+ elif event.event_type == FailureEventType.WORKER_OUTAGE_END:
572
+ chunks.append("Worker capacity restored.")
573
+ return " ".join(chunks)
574
+
575
+ def step(
576
+ self,
577
+ action: WorkflowArenaAction,
578
+ timeout_s: float | None = None,
579
+ **kwargs: Any,
580
+ ) -> WorkflowArenaObservation:
581
+ """Apply a dispatch or wait action using event-driven semantics."""
582
+
583
+ del timeout_s, kwargs
584
+ episode, env_state = self._require_episode()
585
+ task_map = {task.task_id: task for task in episode.tasks}
586
+ self._state.step_count += 1
587
+ self._set_recent_failure_events(env_state, [])
588
+
589
+ limit_termination = self._maybe_terminate_for_limits()
590
+ if limit_termination is not None:
591
+ return limit_termination
592
+
593
+ if action.action_type == WorkflowActionType.WAIT and action.task_ids:
594
+ return self._apply_invalid("wait() must not include task_ids.")
595
+
596
+ if action.action_type == WorkflowActionType.DISPATCH:
597
+ if not action.task_ids:
598
+ return self._apply_invalid(
599
+ "dispatch(task_ids=[...]) requires at least one task id."
600
+ )
601
+ if len(set(action.task_ids)) != len(action.task_ids):
602
+ return self._apply_invalid(
603
+ "dispatch(task_ids=[...]) must not contain duplicate task ids."
604
+ )
605
+
606
+ free_workers = self._effective_worker_capacity(env_state) - len(
607
+ env_state.running_task_ids
608
+ )
609
+ if len(action.task_ids) > max(0, free_workers):
610
+ return self._apply_invalid(
611
+ "dispatch(task_ids=[...]) cannot exceed available worker capacity.",
612
+ penalty=self.OVERCAPACITY_INVALID_ACTION_PENALTY,
613
+ )
614
+
615
+ unknown_tasks = [task_id for task_id in action.task_ids if task_id not in task_map]
616
+ if unknown_tasks:
617
+ return self._apply_invalid(f"Unknown task ids: {unknown_tasks}.")
618
+
619
+ not_ready = [
620
+ task_id
621
+ for task_id in action.task_ids
622
+ if env_state.task_statuses[task_id] != TaskStatus.READY
623
+ ]
624
+ if not_ready:
625
+ return self._apply_invalid(
626
+ f"Only ready tasks can be dispatched: {not_ready}."
627
+ )
628
+
629
+ prev_utilization_potential, prev_criticality_potential = self._dispatch_potential(
630
+ env_state, task_map
631
+ )
632
+ for task_id in action.task_ids:
633
+ task = task_map[task_id]
634
+ env_state.task_statuses[task_id] = TaskStatus.RUNNING
635
+ env_state.task_start_times[task_id] = env_state.current_time
636
+ env_state.task_assigned_finish_times[task_id] = (
637
+ env_state.current_time + task.duration
638
+ )
639
+ env_state.running_task_ids.append(task_id)
640
+ env_state.ready_task_ids.remove(task_id)
641
+ env_state.running_task_ids.sort()
642
+ next_utilization_potential, next_criticality_potential = self._dispatch_potential(
643
+ env_state, task_map
644
+ )
645
+ breakdown = RewardBreakdown(
646
+ utilization_reward=round(
647
+ next_utilization_potential - prev_utilization_potential, 4
648
+ ),
649
+ criticality_reward=round(
650
+ next_criticality_potential - prev_criticality_potential, 4
651
+ ),
652
+ )
653
+ reward = round(
654
+ breakdown.utilization_reward + breakdown.criticality_reward,
655
+ 4,
656
+ )
657
+ self._cumulative_reward += reward
658
+ observation = self._base_observation(
659
+ reward=reward,
660
+ breakdown=breakdown,
661
+ note="Tasks dispatched. Use wait() to advance to the next completion event.",
662
+ done=False,
663
+ )
664
+ observation.received_action = action.model_dump(mode="json")
665
+ observation.metadata["received_action"] = action.model_dump(mode="json")
666
+ return observation
667
+
668
+ if not env_state.running_task_ids:
669
+ return self._apply_invalid("wait() requires at least one running task.")
670
+
671
+ recent_events: list[WorkflowFailureEvent] = []
672
+ avoidable_wait_penalty = 0.0
673
+ if env_state.ready_task_ids:
674
+ free_workers = self._effective_worker_capacity(env_state) - len(
675
+ env_state.running_task_ids
676
+ )
677
+ if free_workers > 0:
678
+ avoidable_wait_penalty = self.AVOIDABLE_WAIT_PENALTY_PER_SLOT * min(
679
+ free_workers,
680
+ len(env_state.ready_task_ids),
681
+ )
682
+
683
+ self._maybe_start_worker_outage(env_state, recent_events)
684
+
685
+ next_completion_time = min(
686
+ env_state.task_assigned_finish_times[task_id]
687
+ for task_id in env_state.running_task_ids
688
+ )
689
+ target_time = next_completion_time
690
+ budget_hit_before_completion = False
691
+ if env_state.time_budget is not None and env_state.time_budget < next_completion_time:
692
+ target_time = env_state.time_budget
693
+ budget_hit_before_completion = True
694
+
695
+ elapsed = target_time - env_state.current_time
696
+ env_state.cumulative_busy_time += elapsed * len(env_state.running_task_ids)
697
+ env_state.current_time = target_time
698
+ self._maybe_end_worker_outage(env_state, recent_events)
699
+
700
+ if budget_hit_before_completion:
701
+ unfinished_penalty = self._unfinished_task_penalty(env_state.current_time)
702
+ terminal_score = self._terminal_score()
703
+ breakdown = RewardBreakdown(
704
+ idle_penalty=round(avoidable_wait_penalty, 4),
705
+ terminal_makespan_score=round(terminal_score, 4),
706
+ unfinished_task_penalty=round(unfinished_penalty, 4),
707
+ )
708
+ reward = round(
709
+ breakdown.idle_penalty
710
+ + breakdown.terminal_makespan_score
711
+ + breakdown.unfinished_task_penalty,
712
+ 4,
713
+ )
714
+ self._set_recent_failure_events(env_state, recent_events)
715
+ note = self._wait_note(
716
+ completed_now=[],
717
+ failed_now=[],
718
+ unlocked=[],
719
+ recent_events=recent_events,
720
+ time_budget_hit=True,
721
+ )
722
+ observation = self._terminate_episode(
723
+ note=note,
724
+ breakdown=breakdown,
725
+ reward=reward,
726
+ reason="time_budget",
727
+ )
728
+ observation.received_action = action.model_dump(mode="json")
729
+ observation.metadata["received_action"] = action.model_dump(mode="json")
730
+ return observation
731
+
732
+ completed_candidates = sorted(
733
+ [
734
+ task_id
735
+ for task_id in env_state.running_task_ids
736
+ if env_state.task_assigned_finish_times[task_id] == next_completion_time
737
+ ]
738
+ )
739
+ completed_now: list[str] = []
740
+ failed_now: list[str] = []
741
+ for task_id in completed_candidates:
742
+ env_state.running_task_ids.remove(task_id)
743
+ del env_state.task_assigned_finish_times[task_id]
744
+ if self._should_retry_fail(task_id):
745
+ env_state.task_attempt_counts[task_id] += 1
746
+ env_state.task_statuses[task_id] = TaskStatus.READY
747
+ env_state.task_start_times.pop(task_id, None)
748
+ env_state.task_end_times.pop(task_id, None)
749
+ if task_id not in env_state.ready_task_ids:
750
+ env_state.ready_task_ids.append(task_id)
751
+ failed_now.append(task_id)
752
+ recent_events.append(
753
+ WorkflowFailureEvent(
754
+ event_type=FailureEventType.TASK_RETRY_FAILURE,
755
+ time=next_completion_time,
756
+ task_id=task_id,
757
+ detail=f"{task_id} failed and returned to ready.",
758
+ )
759
+ )
760
+ else:
761
+ env_state.task_statuses[task_id] = TaskStatus.COMPLETED
762
+ env_state.task_end_times[task_id] = next_completion_time
763
+ env_state.completed_task_ids.append(task_id)
764
+ completed_now.append(task_id)
765
+
766
+ env_state.completed_task_ids.sort()
767
+ env_state.ready_task_ids.sort()
768
+ unlocked = self._transition_unlocks(completed_now)
769
+
770
+ completion_reward = sum(
771
+ 0.04 + 0.01 * task_map[task_id].priority for task_id in completed_now
772
+ )
773
+ deadline_reward = 0.0
774
+ criticality_reward = 0.0
775
+ for task_id in completed_now:
776
+ task = task_map[task_id]
777
+ if task.deadline is not None:
778
+ lateness = next_completion_time - task.deadline
779
+ deadline_reward += 0.05 if lateness <= 0 else -0.02 * lateness
780
+ criticality_reward += 0.03 * task.criticality
781
+
782
+ utilization_reward = 0.06 * (
783
+ elapsed
784
+ * (len(completed_candidates) + len(env_state.running_task_ids))
785
+ / max(1, self._config.worker_count)
786
+ )
787
+ idle_penalty = 0.0
788
+ if not env_state.running_task_ids and env_state.ready_task_ids:
789
+ idle_penalty = -0.03 * len(env_state.ready_task_ids)
790
+
791
+ done = len(env_state.completed_task_ids) == len(episode.tasks)
792
+ breakdown = RewardBreakdown(
793
+ completion_reward=round(completion_reward, 4),
794
+ utilization_reward=round(utilization_reward, 4),
795
+ deadline_reward=round(deadline_reward, 4),
796
+ criticality_reward=round(criticality_reward, 4),
797
+ idle_penalty=round(idle_penalty + avoidable_wait_penalty, 4),
798
+ terminal_makespan_score=round(self._terminal_score() if done else 0.0, 4),
799
+ )
800
+ reward = round(
801
+ breakdown.completion_reward
802
+ + breakdown.utilization_reward
803
+ + breakdown.deadline_reward
804
+ + breakdown.criticality_reward
805
+ + breakdown.idle_penalty
806
+ + breakdown.terminal_makespan_score,
807
+ 4,
808
+ )
809
+
810
+ budget_exhausted_now = (
811
+ not done
812
+ and env_state.time_budget is not None
813
+ and env_state.current_time >= env_state.time_budget
814
+ )
815
+ if budget_exhausted_now:
816
+ unfinished_penalty = self._unfinished_task_penalty(env_state.current_time)
817
+ breakdown.unfinished_task_penalty = round(unfinished_penalty, 4)
818
+ breakdown.terminal_makespan_score = round(self._terminal_score(), 4)
819
+ reward = round(
820
+ reward
821
+ + breakdown.unfinished_task_penalty
822
+ + breakdown.terminal_makespan_score,
823
+ 4,
824
+ )
825
+ self._set_recent_failure_events(env_state, recent_events)
826
+ note = self._wait_note(
827
+ completed_now=completed_now,
828
+ failed_now=failed_now,
829
+ unlocked=unlocked,
830
+ recent_events=recent_events,
831
+ time_budget_hit=True,
832
+ )
833
+ observation = self._terminate_episode(
834
+ note=note,
835
+ breakdown=breakdown,
836
+ reward=reward,
837
+ reason="time_budget",
838
+ )
839
+ observation.received_action = action.model_dump(mode="json")
840
+ observation.metadata["received_action"] = action.model_dump(mode="json")
841
+ observation.metadata["completed_now"] = completed_now
842
+ observation.metadata["unlocked_now"] = unlocked
843
+ observation.metadata["failed_now"] = failed_now
844
+ return observation
845
+
846
+ self._cumulative_reward += reward
847
+ self._set_recent_failure_events(env_state, recent_events)
848
+ observation = self._base_observation(
849
+ reward=reward,
850
+ breakdown=breakdown,
851
+ note=self._wait_note(
852
+ completed_now=completed_now,
853
+ failed_now=failed_now,
854
+ unlocked=unlocked,
855
+ recent_events=recent_events,
856
+ ),
857
+ done=done,
858
+ benchmark_score_override=self._benchmark_score() if done else None,
859
+ )
860
+ observation.received_action = action.model_dump(mode="json")
861
+ observation.metadata["received_action"] = action.model_dump(mode="json")
862
+ observation.metadata["completed_now"] = completed_now
863
+ observation.metadata["unlocked_now"] = unlocked
864
+ observation.metadata["failed_now"] = failed_now
865
+ if done:
866
+ observation.benchmark_score = observation.success_metrics.benchmark_score
867
+ return observation
868
+
869
+ @property
870
+ def state(self) -> State:
871
+ """Expose generic OpenEnv state metadata."""
872
+
873
+ return self._state
uv.lock ADDED
The diff for this file is too large to render. See raw diff