Spaces:
Running
Running
Commit ·
aea0016
verified ·
0
Parent(s):
init: WorkFlowArena
Browse files- .gitignore +10 -0
- __init__.py +34 -0
- client.py +52 -0
- generator.py +185 -0
- models.py +483 -0
- openenv.yaml +7 -0
- presets.py +96 -0
- pyproject.toml +47 -0
- server/Dockerfile +80 -0
- server/__init__.py +11 -0
- server/app.py +90 -0
- server/requirements.txt +6 -0
- server/ui.py +1270 -0
- server/workflow_arena_environment.py +873 -0
- uv.lock +0 -0
.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python-generated files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[oc]
|
| 4 |
+
build/
|
| 5 |
+
dist/
|
| 6 |
+
wheels/
|
| 7 |
+
*.egg-info
|
| 8 |
+
|
| 9 |
+
# Virtual environments
|
| 10 |
+
.venv
|
__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""WorkflowArena package exports."""
|
| 8 |
+
|
| 9 |
+
from workflow_arena.client import WorkflowArenaEnv
|
| 10 |
+
from workflow_arena.generator import generate_episode
|
| 11 |
+
from workflow_arena.models import (
|
| 12 |
+
DifficultyPreset,
|
| 13 |
+
EpisodeConfig,
|
| 14 |
+
TaskStatus,
|
| 15 |
+
WorkflowActionType,
|
| 16 |
+
WorkflowArenaAction,
|
| 17 |
+
WorkflowArenaObservation,
|
| 18 |
+
)
|
| 19 |
+
from workflow_arena.presets import PRESET_CONFIGS, get_preset_config
|
| 20 |
+
from workflow_arena.server.workflow_arena_environment import WorkflowArenaEnvironment
|
| 21 |
+
|
| 22 |
+
__all__ = [
|
| 23 |
+
"DifficultyPreset",
|
| 24 |
+
"EpisodeConfig",
|
| 25 |
+
"PRESET_CONFIGS",
|
| 26 |
+
"TaskStatus",
|
| 27 |
+
"WorkflowActionType",
|
| 28 |
+
"WorkflowArenaAction",
|
| 29 |
+
"WorkflowArenaEnv",
|
| 30 |
+
"WorkflowArenaEnvironment",
|
| 31 |
+
"WorkflowArenaObservation",
|
| 32 |
+
"generate_episode",
|
| 33 |
+
"get_preset_config",
|
| 34 |
+
]
|
client.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""WorkflowArena client."""
|
| 8 |
+
|
| 9 |
+
from typing import Dict
|
| 10 |
+
|
| 11 |
+
from openenv.core import EnvClient
|
| 12 |
+
from openenv.core.client_types import StepResult
|
| 13 |
+
from openenv.core.env_server.types import State
|
| 14 |
+
|
| 15 |
+
from workflow_arena.models import WorkflowArenaAction, WorkflowArenaObservation
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class WorkflowArenaEnv(
    EnvClient[WorkflowArenaAction, WorkflowArenaObservation, State]
):
    """Typed client for the WorkflowArena server."""

    def _step_payload(self, action: WorkflowArenaAction) -> Dict:
        """Serialize a typed action into the JSON body the server expects."""
        return action.model_dump(mode="json")

    def _parse_result(self, payload: Dict) -> StepResult[WorkflowArenaObservation]:
        """Build a typed StepResult from a raw server response."""
        done = payload.get("done", False)
        reward = payload.get("reward")
        raw_observation = dict(payload.get("observation", {}))
        # done/reward are mirrored into the observation so agents can read
        # them without inspecting the StepResult wrapper.
        raw_observation["done"] = done
        raw_observation["reward"] = reward
        observation = WorkflowArenaObservation.model_validate(raw_observation)
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> State:
        """Map a raw state payload onto the generic OpenEnv State type."""
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
|
generator.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Seeded workflow DAG generator and derived static metrics for WorkflowArena."""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import random
|
| 12 |
+
|
| 13 |
+
from workflow_arena.models import (
|
| 14 |
+
EpisodeConfig,
|
| 15 |
+
TaskStatus,
|
| 16 |
+
WorkflowEnvStateSnapshot,
|
| 17 |
+
WorkflowEpisodeSpec,
|
| 18 |
+
WorkflowTaskSpec,
|
| 19 |
+
)
|
| 20 |
+
from workflow_arena.presets import get_preset_config
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _task_id(index: int) -> str:
|
| 24 |
+
return f"task_{index:02d}"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _compute_earliest_start(task_map: dict[str, WorkflowTaskSpec], task_id: str) -> int:
|
| 28 |
+
task = task_map[task_id]
|
| 29 |
+
if not task.dependencies:
|
| 30 |
+
return 0
|
| 31 |
+
return max(
|
| 32 |
+
_compute_earliest_start(task_map, dep_id) + task_map[dep_id].duration
|
| 33 |
+
for dep_id in task.dependencies
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _compute_critical_path(task_map: dict[str, WorkflowTaskSpec], task_id: str) -> int:
|
| 38 |
+
task = task_map[task_id]
|
| 39 |
+
if not task.dependents:
|
| 40 |
+
return task.duration
|
| 41 |
+
return task.duration + max(
|
| 42 |
+
_compute_critical_path(task_map, child_id) for child_id in task.dependents
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _compute_downstream_count(
|
| 47 |
+
task_map: dict[str, WorkflowTaskSpec], task_id: str, seen: set[str] | None = None
|
| 48 |
+
) -> int:
|
| 49 |
+
task = task_map[task_id]
|
| 50 |
+
local_seen = set() if seen is None else seen
|
| 51 |
+
count = 0
|
| 52 |
+
for child_id in task.dependents:
|
| 53 |
+
if child_id in local_seen:
|
| 54 |
+
continue
|
| 55 |
+
local_seen.add(child_id)
|
| 56 |
+
count += 1 + _compute_downstream_count(task_map, child_id, local_seen)
|
| 57 |
+
return count
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _estimate_deadline(
|
| 61 |
+
task: WorkflowTaskSpec,
|
| 62 |
+
workflow_critical_path: int,
|
| 63 |
+
rng: random.Random,
|
| 64 |
+
tightness: float,
|
| 65 |
+
) -> int:
|
| 66 |
+
slack_allowance = max(1, int(round((workflow_critical_path - task.earliest_start) * (1.15 - tightness))))
|
| 67 |
+
jitter = rng.randint(0, max(1, task.duration // 2))
|
| 68 |
+
return task.earliest_start + task.duration + slack_allowance + jitter
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def generate_episode(
    config: EpisodeConfig,
) -> tuple[WorkflowEpisodeSpec, WorkflowEnvStateSnapshot]:
    """Generate a deterministic workflow episode from a preset and seed.

    Returns the static episode spec (the task DAG plus derived metrics) and
    the initial mutable environment state.  All randomness flows through a
    single ``random.Random(seed)`` instance, so equal configs yield equal
    episodes.
    """

    preset_config = get_preset_config(config.preset)
    # Falsy worker_count falls back to the preset's worker count.
    worker_count = config.worker_count or preset_config.worker_count
    resolved_config = config.model_copy(update={"worker_count": worker_count})
    rng = random.Random(resolved_config.seed)
    task_count = rng.randint(preset_config.min_tasks, preset_config.max_tasks)

    dependency_map: dict[str, list[str]] = {}
    dependent_map: dict[str, list[str]] = {}
    # Task ids are 1-based ("task_01", "task_02", ...).
    task_ids = [_task_id(index + 1) for index in range(task_count)]

    # Sample a DAG: edges only ever point from an earlier task id to a later
    # one, which guarantees acyclicity by construction.
    for index, task_id in enumerate(task_ids):
        candidates = task_ids[:index]
        dependencies: list[str] = []
        if candidates:
            for candidate in candidates:
                if rng.random() < preset_config.edge_probability:
                    dependencies.append(candidate)
            # Bias against fully disconnected tasks: with probability 0.6 a
            # task that drew no edges still gets one random dependency.
            if not dependencies and index > 0 and rng.random() < 0.6:
                dependencies.append(rng.choice(candidates))
        # Dedupe while keeping the deterministic task-id ordering.
        dependency_map[task_id] = sorted(set(dependencies), key=task_ids.index)
        dependent_map[task_id] = []

    # Invert the dependency edges to get each task's dependents.
    for task_id, dependencies in dependency_map.items():
        for dependency in dependencies:
            dependent_map[dependency].append(task_id)

    tasks = [
        WorkflowTaskSpec(
            task_id=task_id,
            duration=rng.randint(preset_config.duration_min, preset_config.duration_max),
            priority=rng.randint(preset_config.priority_min, preset_config.priority_max),
            dependencies=dependency_map[task_id],
            dependents=sorted(dependent_map[task_id], key=task_ids.index),
            deadline=None,  # assigned below once critical-path data exists
        )
        for task_id in task_ids
    ]

    task_map = {task.task_id: task for task in tasks}

    # Derive per-task static metrics plus the workflow-wide critical path.
    workflow_critical_path = 0
    for task in tasks:
        task.earliest_start = _compute_earliest_start(task_map, task.task_id)
        task.critical_path_length = _compute_critical_path(task_map, task.task_id)
        task.downstream_count = _compute_downstream_count(task_map, task.task_id)
        workflow_critical_path = max(
            workflow_critical_path, task.earliest_start + task.duration
        )

    # earliest_start + duration and critical_path_length measure the same
    # bound from opposite ends; take whichever is larger.
    workflow_critical_path = max(
        workflow_critical_path,
        max(task.critical_path_length for task in tasks),
    )

    # task_count >= min_tasks >= 2, so the `if tasks` guards are defensive.
    max_downstream = max(task.downstream_count for task in tasks) if tasks else 1
    max_critical_path = max(task.critical_path_length for task in tasks) if tasks else 1

    for task in tasks:
        latest_start = max(
            task.earliest_start, workflow_critical_path - task.critical_path_length
        )
        task.slack = max(0, latest_start - task.earliest_start)
        # Blend of critical-path share (70%) and downstream fan-out (30%),
        # normalized to the workflow's maxima.
        task.criticality = round(
            0.7 * (task.critical_path_length / max_critical_path)
            + 0.3 * (task.downstream_count / max(1, max_downstream)),
            4,
        )
        task.deadline = _estimate_deadline(
            task=task,
            workflow_critical_path=workflow_critical_path,
            rng=rng,
            tightness=preset_config.deadline_tightness,
        )

    episode = WorkflowEpisodeSpec(
        config=resolved_config,
        preset_config=preset_config,
        tasks=tasks,
    )

    # Initial scheduling state: dependency-free tasks start READY, the rest
    # BLOCKED; nothing is running or completed yet.
    ready_task_ids = [task.task_id for task in tasks if not task.dependencies]
    blocked_task_ids = [task.task_id for task in tasks if task.dependencies]

    state = WorkflowEnvStateSnapshot(
        episode_id=f"seed-{resolved_config.seed}",
        current_time=0,
        task_statuses={
            task.task_id: (
                TaskStatus.READY if not task.dependencies else TaskStatus.BLOCKED
            )
            for task in tasks
        },
        running_task_ids=[],
        completed_task_ids=[],
        ready_task_ids=ready_task_ids,
        blocked_task_ids=blocked_task_ids,
        task_start_times={},
        task_end_times={},
        task_remaining_dependencies={
            task.task_id: len(task.dependencies) for task in tasks
        },
        task_assigned_finish_times={},
        task_attempt_counts={task.task_id: 0 for task in tasks},
        cumulative_busy_time=0,
        time_budget=None,  # NOTE(review): preset time_budget_multiplier is not applied here — confirm the server sets this
        degraded_workers=0,
        active_worker_outage_until=None,
        recent_failure_events=[],
    )
    return episode, state
|
models.py
ADDED
|
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Typed models for WorkflowArena.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from enum import Enum
|
| 14 |
+
|
| 15 |
+
from openenv.core.env_server.types import Action, Observation
|
| 16 |
+
from pydantic import BaseModel, Field
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TaskStatus(str, Enum):
    """Allowed lifecycle states for a workflow task."""

    # Listed in lifecycle order: blocked -> ready -> running -> completed.
    # str-valued so statuses serialize directly into JSON payloads.
    BLOCKED = "blocked"
    READY = "ready"
    RUNNING = "running"
    COMPLETED = "completed"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DifficultyPreset(str, Enum):
    """Initial task presets required by the hackathon."""

    # str-valued so presets can appear verbatim in configs and JSON.
    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class WorkflowActionType(str, Enum):
    """Explicit action space for the scheduler agent."""

    # DISPATCH assigns ready tasks to workers; WAIT advances simulated time.
    DISPATCH = "dispatch"
    WAIT = "wait"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class RewardBreakdown(BaseModel):
    """Named reward channels for shaped feedback."""

    # Per-step shaping channels.
    completion_reward: float = Field(
        default=0.0, description="Reward for completing tasks."
    )
    utilization_reward: float = Field(
        default=0.0, description="Reward for keeping workers busy."
    )
    deadline_reward: float = Field(
        default=0.0, description="Reward or penalty tied to deadlines."
    )
    criticality_reward: float = Field(
        default=0.0,
        description="Reward for prioritizing critical-path work appropriately.",
    )
    # Penalty channels. NOTE(review): the sign convention (stored negative vs
    # subtracted by the server) is not visible here — confirm in the server.
    idle_penalty: float = Field(
        default=0.0, description="Penalty for leaving workers idle."
    )
    invalid_action_penalty: float = Field(
        default=0.0,
        description="Penalty for malformed or infeasible actions.",
    )
    # Terminal-only channels, applied at episode end.
    terminal_makespan_score: float = Field(
        default=0.0,
        description="Terminal score based on final schedule quality.",
    )
    unfinished_task_penalty: float = Field(
        default=0.0,
        description="Terminal penalty for unfinished work at episode end.",
    )
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class FailureEventType(str, Enum):
    """Failure events surfaced to agents and the UI."""

    # Worker-capacity events bracket an outage window.
    WORKER_OUTAGE_START = "worker_outage_start"
    WORKER_OUTAGE_END = "worker_outage_end"
    # A task completion that was converted into a retry instead.
    TASK_RETRY_FAILURE = "task_retry_failure"
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class WorkflowFailureEvent(BaseModel):
    """Structured failure event emitted by the environment."""

    event_type: FailureEventType = Field(..., description="Failure category.")
    time: int = Field(..., ge=0, description="Simulated time when the event was observed.")
    # task_id is only set for task-scoped events (e.g. retry failures).
    task_id: str | None = Field(default=None, description="Task affected by the event, if any.")
    worker_delta: int = Field(default=0, description="Net temporary change in usable workers.")
    duration: int | None = Field(default=None, ge=0, description="Outage duration when applicable.")
    detail: str = Field(default="", description="Short human-readable summary.")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class WorkflowTaskView(BaseModel):
    """Compact task payload used in observations and the future UI."""

    task_id: str = Field(..., description="Stable task identifier.")
    status: TaskStatus = Field(..., description="Current task lifecycle state.")
    duration: int = Field(
        ..., ge=1, description="Task runtime in simulated time units."
    )
    priority: int = Field(..., ge=0, description="Priority weight for the task.")
    dependencies: list[str] = Field(
        default_factory=list,
        description="Upstream task ids that must complete first.",
    )
    deadline: int | None = Field(
        default=None,
        ge=0,
        description="Optional deadline in simulated time units.",
    )
    criticality: float | None = Field(
        default=None,
        description="Derived importance score from the DAG structure.",
    )
    # NOTE(review): slack is float|None here but int on WorkflowTaskSpec —
    # confirm whether the widening to float is intentional.
    slack: float | None = Field(
        default=None,
        description="Derived slack estimate for scheduling decisions.",
    )
    downstream_count: int = Field(
        default=0,
        ge=0,
        description="Count of downstream dependents reachable from this task.",
    )
    # start_time/end_time are None until the task has been dispatched.
    start_time: int | None = Field(
        default=None,
        ge=0,
        description="Simulated start time if the task is running or completed.",
    )
    end_time: int | None = Field(
        default=None,
        ge=0,
        description="Simulated end time if the task is completed or scheduled to finish.",
    )
    attempt_count: int = Field(
        default=0,
        ge=0,
        description="Number of retry attempts already consumed by this task.",
    )
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class WorkflowTaskSpec(BaseModel):
    """Static task specification generated at episode reset.

    The DAG fields (dependencies/dependents) are set by the generator when
    the graph is sampled; the derived metrics (downstream_count,
    critical_path_length, earliest_start, slack, criticality, deadline) are
    filled in afterwards by a second pass over the finished graph.
    """

    task_id: str = Field(..., description="Stable task identifier.")
    duration: int = Field(..., ge=1, description="Task runtime in simulated time units.")
    priority: int = Field(..., ge=0, description="Priority weight for the task.")
    dependencies: list[str] = Field(
        default_factory=list,
        description="Upstream task ids that must complete first.",
    )
    dependents: list[str] = Field(
        default_factory=list,
        description="Downstream task ids that depend on this task.",
    )
    deadline: int | None = Field(
        default=None,
        ge=0,
        description="Optional deadline in simulated time units.",
    )
    # --- derived metrics, populated after graph construction ---
    downstream_count: int = Field(
        default=0,
        ge=0,
        description="Number of downstream tasks reachable from this node.",
    )
    critical_path_length: int = Field(
        default=0,
        ge=0,
        description="Duration-weighted path length from this task to a sink.",
    )
    earliest_start: int = Field(
        default=0,
        ge=0,
        description="Earliest feasible start time under dependency constraints.",
    )
    slack: int = Field(
        default=0,
        ge=0,
        description="Scheduling slack measured in simulated time units.",
    )
    criticality: float = Field(
        default=0.0,
        description="Normalized importance score derived from critical path and downstream impact.",
    )
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class ProgressSummary(BaseModel):
    """Counts by task lifecycle state."""

    # total should equal blocked + ready + running + completed; no validator
    # enforces this here.
    total: int = Field(default=0, ge=0)
    blocked: int = Field(default=0, ge=0)
    ready: int = Field(default=0, ge=0)
    running: int = Field(default=0, ge=0)
    completed: int = Field(default=0, ge=0)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class EpisodeConfig(BaseModel):
    """Reset-time knobs that define the episode."""

    preset: DifficultyPreset = Field(
        default=DifficultyPreset.EASY,
        description="Difficulty preset for the episode generator.",
    )
    seed: int = Field(
        default=0, description="Seed for deterministic episode generation."
    )
    # None means "defer to the preset's worker_count".  The generator resolves
    # this via `config.worker_count or preset_config.worker_count`; with the
    # previous always-truthy default of 2 that preset fallback was unreachable
    # and presets could never control worker capacity.
    worker_count: int | None = Field(
        default=None,
        ge=1,
        description=(
            "Number of identical workers available to the scheduler. "
            "Defaults to the preset's worker count when omitted."
        ),
    )
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class GraderTarget(BaseModel):
    """High-level target bands for each preset's grader."""

    # Both fields are human-facing prose, not machine-read thresholds.
    description: str = Field(..., description="What good performance means for the preset.")
    score_band_hint: str = Field(..., description="Human-readable interpretation of scores.")
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class DifficultyPresetConfig(BaseModel):
    """Concrete generator knobs for a preset.

    NOTE(review): no validator enforces min <= max for the task/duration/
    priority/outage-duration pairs — confirm the preset definitions keep
    them ordered.
    """

    preset: DifficultyPreset = Field(..., description="Preset identifier.")
    # DAG-size and shape knobs.
    min_tasks: int = Field(..., ge=2)
    max_tasks: int = Field(..., ge=2)
    edge_probability: float = Field(..., ge=0.0, le=1.0)
    # Per-task sampling ranges (inclusive).
    duration_min: int = Field(..., ge=1)
    duration_max: int = Field(..., ge=1)
    priority_min: int = Field(..., ge=0)
    priority_max: int = Field(..., ge=0)
    worker_count: int = Field(..., ge=1)
    deadline_tightness: float = Field(
        ...,
        ge=0.0,
        description="Larger values mean tighter deadlines.",
    )
    time_budget_multiplier: float | None = Field(
        default=None,
        gt=0.0,
        description="Optional multiplier over the theoretical lower-bound makespan.",
    )
    # Hard-mode stochastic failure knobs; all default to "disabled".
    worker_outage_rate: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Chance of a hard-mode worker outage being sampled on a wait transition.",
    )
    worker_outage_duration_min: int = Field(
        default=0,
        ge=0,
        description="Minimum outage duration in simulated time units.",
    )
    worker_outage_duration_max: int = Field(
        default=0,
        ge=0,
        description="Maximum outage duration in simulated time units.",
    )
    task_retry_failure_rate: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Chance that a hard-mode task completion becomes a retry failure.",
    )
    max_task_retries: int = Field(
        default=0,
        ge=0,
        description="Maximum number of retry failures a task may suffer before it must complete.",
    )
    grader_target: GraderTarget = Field(
        ...,
        description="Preset-specific grader interpretation.",
    )
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class WorkflowEpisodeSpec(BaseModel):
    """Static episode description produced by the generator.

    Immutable for the life of an episode; mutable progress lives in
    WorkflowEnvStateSnapshot instead.
    """

    config: EpisodeConfig = Field(..., description="Reset-time configuration.")
    preset_config: DifficultyPresetConfig = Field(..., description="Resolved preset parameters.")
    tasks: list[WorkflowTaskSpec] = Field(..., description="Generated workflow tasks.")
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class WorkflowEnvStateSnapshot(BaseModel):
    """Serializable environment state for the current episode.

    All task-keyed dicts are indexed by WorkflowTaskSpec.task_id.  The four
    id lists partition the task set by status; the authoritative per-task
    status also lives in task_statuses.
    """

    episode_id: str = Field(..., description="Stable current episode identifier.")
    current_time: int = Field(default=0, ge=0, description="Current simulated time.")
    task_statuses: dict[str, TaskStatus] = Field(
        default_factory=dict,
        description="Current task status by task id.",
    )
    running_task_ids: list[str] = Field(
        default_factory=list,
        description="Tasks currently consuming workers.",
    )
    completed_task_ids: list[str] = Field(
        default_factory=list,
        description="Tasks that have completed.",
    )
    ready_task_ids: list[str] = Field(
        default_factory=list,
        description="Tasks currently ready for dispatch.",
    )
    blocked_task_ids: list[str] = Field(
        default_factory=list,
        description="Tasks still blocked on dependencies.",
    )
    # Timing bookkeeping; entries appear as tasks start/finish.
    task_start_times: dict[str, int] = Field(
        default_factory=dict,
        description="Simulated start time by task id.",
    )
    task_end_times: dict[str, int] = Field(
        default_factory=dict,
        description="Simulated completion time by task id.",
    )
    task_remaining_dependencies: dict[str, int] = Field(
        default_factory=dict,
        description="Remaining unfinished prerequisites by task id.",
    )
    task_assigned_finish_times: dict[str, int] = Field(
        default_factory=dict,
        description="Predicted completion times for currently running tasks.",
    )
    task_attempt_counts: dict[str, int] = Field(
        default_factory=dict,
        description="Retry attempts consumed by each task.",
    )
    cumulative_busy_time: int = Field(
        default=0,
        ge=0,
        description="Aggregate worker busy time accrued so far.",
    )
    time_budget: int | None = Field(
        default=None,
        ge=0,
        description="Optional terminal time budget for the episode.",
    )
    # Hard-mode outage bookkeeping.
    degraded_workers: int = Field(
        default=0,
        ge=0,
        description="Workers temporarily removed from usable capacity.",
    )
    active_worker_outage_until: int | None = Field(
        default=None,
        ge=0,
        description="Time when the current worker outage expires, if any.",
    )
    recent_failure_events: list[WorkflowFailureEvent] = Field(
        default_factory=list,
        description="Failure events generated on the latest transition.",
    )
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class SuccessMetrics(BaseModel):
    """Primary quality metrics used for grading and demos."""

    # Optional metrics are None until the episode reaches a terminal state.
    makespan: int | None = Field(
        default=None, description="Total simulated completion time."
    )
    worker_utilization: float | None = Field(
        default=None,
        description="Fraction of available worker time that was used.",
    )
    deadline_miss_count: int = Field(
        default=0, ge=0, description="Missed task deadlines."
    )
    unfinished_task_count: int = Field(
        default=0, ge=0, description="Tasks left incomplete at terminal time."
    )
    weighted_priority_completion: float | None = Field(
        default=None,
        description="Priority-weighted on-time completion score.",
    )
    benchmark_score: float | None = Field(
        default=None,
        description="Deterministic terminal benchmark score in the 0.0-1.0 range.",
    )
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
class WorkflowArenaAction(Action):
    """Strict action space for the workflow scheduler."""

    action_type: WorkflowActionType = Field(
        ...,
        description="Dispatch ready tasks or wait for the next completion event.",
    )
    # Per the field contract, task_ids must be empty when action_type is WAIT;
    # presumably the server rejects a non-empty list there — confirm.
    task_ids: list[str] = Field(
        default_factory=list,
        description="Task ids to dispatch. Must be empty for wait().",
    )
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class WorkflowArenaObservation(Observation):
    """Compact, typed observation contract for WorkflowArena.

    One instance is returned per reset/step. Fields are grouped below by
    concern: prompt/config, clock and worker capacity, task lifecycle views,
    reward/scoring, and transition diagnostics.
    """

    # --- Prompt and episode configuration -----------------------------------
    instruction: str = Field(
        default=(
            "Schedule dependency-constrained workflow tasks on limited workers using "
            "dispatch(task_ids=[...]) or wait()."
        ),
        description="Short prompt shown to inference agents.",
    )
    config: EpisodeConfig = Field(
        default_factory=EpisodeConfig,
        description="Episode generation settings.",
    )
    # --- Clock and worker capacity -------------------------------------------
    current_time: int = Field(default=0, ge=0, description="Current simulated time.")
    total_workers: int = Field(default=2, ge=1, description="Total identical workers.")
    # effective = total - degraded; free counts idle workers among the effective ones.
    effective_workers: int = Field(
        default=2,
        ge=0,
        description="Usable workers after temporary degradation is applied.",
    )
    degraded_workers: int = Field(
        default=0,
        ge=0,
        description="Workers currently unavailable due to outages.",
    )
    free_workers: int = Field(default=2, ge=0, description="Currently idle workers.")
    # Budget fields are None for unbudgeted (e.g. easy-preset) episodes.
    time_budget: int | None = Field(
        default=None,
        ge=0,
        description="Optional terminal time budget for the current episode.",
    )
    time_remaining: int | None = Field(
        default=None,
        description="Remaining time until the episode budget expires, if budgeted.",
    )
    # --- Task lifecycle views ------------------------------------------------
    progress: ProgressSummary = Field(
        default_factory=ProgressSummary,
        description="Task counts by lifecycle state.",
    )
    ready_tasks: list[WorkflowTaskView] = Field(
        default_factory=list,
        description="Ready tasks eligible for dispatch.",
    )
    running_tasks: list[WorkflowTaskView] = Field(
        default_factory=list,
        description="Tasks currently consuming workers.",
    )
    completed_tasks: list[WorkflowTaskView] = Field(
        default_factory=list,
        description="Tasks already completed.",
    )
    blocked_tasks: list[WorkflowTaskView] = Field(
        default_factory=list,
        description="Tasks still waiting on dependencies.",
    )
    # --- Reward and scoring ---------------------------------------------------
    last_reward_breakdown: RewardBreakdown = Field(
        default_factory=RewardBreakdown,
        description="Per-step reward channel breakdown.",
    )
    cumulative_reward: float = Field(default=0.0, description="Running total reward.")
    success_metrics: SuccessMetrics = Field(
        default_factory=SuccessMetrics,
        description="Primary schedule quality metrics.",
    )
    # --- Transition diagnostics ----------------------------------------------
    note: str | None = Field(
        default=None,
        description="Short environment note about the latest transition.",
    )
    validation_error: str | None = Field(
        default=None,
        description="Explicit invalid-action explanation when the previous action failed.",
    )
    termination_reason: str | None = Field(
        default=None,
        description="Terminal reason when the episode ended unsuccessfully.",
    )
    # Duplicates success_metrics.benchmark_score for easier client access.
    benchmark_score: float | None = Field(
        default=None,
        description="Top-level bounded benchmark score for easier client access.",
    )
    recent_failure_events: list[WorkflowFailureEvent] = Field(
        default_factory=list,
        description="Failure events generated on the latest accepted transition.",
    )
    # Echo of the last accepted action, useful for logging and re-prompting.
    received_action: dict[str, object] | None = Field(
        default=None,
        description="Last action accepted by the server for logging and prompting.",
    )
|
openenv.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: workflow_arena
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
| 7 |
+
|
presets.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Difficulty presets for WorkflowArena."""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from workflow_arena.models import DifficultyPreset, DifficultyPresetConfig, GraderTarget
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Canonical difficulty table. Treat entries as immutable: consumers should go
# through get_preset_config(), which hands out a deep copy, instead of mutating
# these shared objects in place.
PRESET_CONFIGS: dict[DifficultyPreset, DifficultyPresetConfig] = {
    # EASY: small DAG, loose deadlines, no time budget, no failure mechanics.
    DifficultyPreset.EASY: DifficultyPresetConfig(
        preset=DifficultyPreset.EASY,
        min_tasks=8,
        max_tasks=12,
        edge_probability=0.14,
        duration_min=1,
        duration_max=4,
        priority_min=1,
        priority_max=4,
        worker_count=3,
        deadline_tightness=0.22,
        time_budget_multiplier=None,  # None -> episode has no hard time budget
        worker_outage_rate=0.0,
        worker_outage_duration_min=0,
        worker_outage_duration_max=0,
        task_retry_failure_rate=0.0,
        max_task_retries=0,
        grader_target=GraderTarget(
            description=(
                "Reward agents that keep workers utilized and avoid obvious idle time on a "
                "small, low-pressure workflow."
            ),
            score_band_hint="0.8+ means near-greedy scheduling, 0.5 is acceptable, below 0.3 is weak.",
        ),
    ),
    # MEDIUM: denser DAG and a time budget, but still no outage/retry events.
    DifficultyPreset.MEDIUM: DifficultyPresetConfig(
        preset=DifficultyPreset.MEDIUM,
        min_tasks=12,
        max_tasks=18,
        edge_probability=0.22,
        duration_min=1,
        duration_max=6,
        priority_min=1,
        priority_max=6,
        worker_count=4,
        deadline_tightness=0.40,
        time_budget_multiplier=1.6,
        worker_outage_rate=0.0,
        worker_outage_duration_min=0,
        worker_outage_duration_max=0,
        task_retry_failure_rate=0.0,
        max_task_retries=0,
        grader_target=GraderTarget(
            description=(
                "Reward agents that balance utilization, deadline adherence, and critical-path "
                "awareness on a moderately branching workflow."
            ),
            score_band_hint="0.75+ is strong, 0.45 to 0.75 is competitive, below 0.3 misses core tradeoffs.",
        ),
    ),
    # HARD: dense DAG, fewer workers, tight budget, plus outages and retries.
    DifficultyPreset.HARD: DifficultyPresetConfig(
        preset=DifficultyPreset.HARD,
        min_tasks=22,
        max_tasks=36,
        edge_probability=0.37,
        duration_min=2,
        duration_max=9,
        priority_min=1,
        priority_max=8,
        worker_count=2,
        deadline_tightness=0.78,
        time_budget_multiplier=1.45,
        worker_outage_rate=0.2,
        worker_outage_duration_min=2,
        worker_outage_duration_max=4,
        task_retry_failure_rate=0.12,
        max_task_retries=1,
        grader_target=GraderTarget(
            description=(
                "Reward agents that identify and schedule long-running critical tasks early while "
                "protecting high-priority deadlines under frequent worker-capacity bottlenecks."
            ),
            score_band_hint="0.7+ is excellent, 0.4 to 0.7 is competent, below 0.25 is poor planning.",
        ),
    ),
}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_preset_config(preset: DifficultyPreset) -> DifficultyPresetConfig:
    """Look up the configuration for *preset* and return a private deep copy.

    Handing back a copy keeps the module-level ``PRESET_CONFIGS`` table
    effectively immutable: callers may mutate the returned object freely
    without affecting other sessions.
    """
    base_config = PRESET_CONFIGS[preset]
    return base_config.model_copy(deep=True)
|
pyproject.toml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-workflow_arena"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "Workflow Arena environment for OpenEnv"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
# Core OpenEnv runtime (provides FastAPI server + HTTP client types)
|
| 18 |
+
# install from github
|
| 19 |
+
# "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
|
| 20 |
+
"openenv-core[core]>=0.2.2",
|
| 21 |
+
"gradio>=5.0.0",
|
| 22 |
+
"plotly>=5.24.0",
|
| 23 |
+
# Environment-specific dependencies
|
| 24 |
+
# Add all dependencies needed for your environment here
|
| 25 |
+
# Examples:
|
| 26 |
+
# "numpy>=1.19.0",
|
| 27 |
+
# "torch>=2.0.0",
|
| 28 |
+
# "gymnasium>=0.29.0",
|
| 29 |
+
# "openspiel>=1.0.0",
|
| 30 |
+
# "smolagents>=1.22.0,<2",
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
[project.optional-dependencies]
|
| 34 |
+
dev = [
|
| 35 |
+
"pytest>=8.0.0",
|
| 36 |
+
"pytest-cov>=4.0.0",
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
[project.scripts]
|
| 40 |
+
# Server entry point - enables running via: uv run --project . server
|
| 41 |
+
# or: python -m workflow_arena.server.app
|
| 42 |
+
server = "workflow_arena.server.app:main"
|
| 43 |
+
|
| 44 |
+
[tool.setuptools]
|
| 45 |
+
include-package-data = true
|
| 46 |
+
packages = ["workflow_arena", "workflow_arena.server"]
|
| 47 |
+
package-dir = { "workflow_arena" = ".", "workflow_arena.server" = "server" }
|
server/Dockerfile
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build using openenv-base
|
| 8 |
+
# This Dockerfile is flexible and works for both:
|
| 9 |
+
# - In-repo environments (with local OpenEnv sources)
|
| 10 |
+
# - Standalone environments (with openenv from PyPI/Git)
|
| 11 |
+
# The build script (openenv build) handles context detection and sets appropriate build args.
|
| 12 |
+
|
| 13 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 14 |
+
FROM ${BASE_IMAGE} AS builder
|
| 15 |
+
|
| 16 |
+
WORKDIR /app
|
| 17 |
+
|
| 18 |
+
# Ensure git is available (required for installing dependencies from VCS)
|
| 19 |
+
RUN apt-get update && \
|
| 20 |
+
apt-get install -y --no-install-recommends git && \
|
| 21 |
+
rm -rf /var/lib/apt/lists/*
|
| 22 |
+
|
| 23 |
+
# Build argument to control whether we're building standalone or in-repo
|
| 24 |
+
ARG BUILD_MODE=in-repo
|
| 25 |
+
ARG ENV_NAME=workflow_arena
|
| 26 |
+
|
| 27 |
+
# Copy environment code (always at root of build context)
|
| 28 |
+
COPY . /app/env
|
| 29 |
+
|
| 30 |
+
# For in-repo builds, openenv is already vendored in the build context
|
| 31 |
+
# For standalone builds, openenv will be installed via pyproject.toml
|
| 32 |
+
WORKDIR /app/env
|
| 33 |
+
|
| 34 |
+
# Ensure uv is available (for local builds where base image lacks it)
|
| 35 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 36 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 37 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 38 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
# Install dependencies using uv sync
|
| 42 |
+
# If uv.lock exists, use it; otherwise resolve on the fly
|
| 43 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 44 |
+
if [ -f uv.lock ]; then \
|
| 45 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 46 |
+
else \
|
| 47 |
+
uv sync --no-install-project --no-editable; \
|
| 48 |
+
fi
|
| 49 |
+
|
| 50 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 51 |
+
if [ -f uv.lock ]; then \
|
| 52 |
+
uv sync --frozen --no-editable; \
|
| 53 |
+
else \
|
| 54 |
+
uv sync --no-editable; \
|
| 55 |
+
fi
|
| 56 |
+
|
| 57 |
+
# Final runtime stage
|
| 58 |
+
FROM ${BASE_IMAGE}
|
| 59 |
+
|
| 60 |
+
WORKDIR /app
|
| 61 |
+
|
| 62 |
+
# Copy the virtual environment from builder
|
| 63 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 64 |
+
|
| 65 |
+
# Copy the environment code
|
| 66 |
+
COPY --from=builder /app/env /app/env
|
| 67 |
+
|
| 68 |
+
# Set PATH to use the virtual environment
|
| 69 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 70 |
+
|
| 71 |
+
# Set PYTHONPATH so imports work correctly
|
| 72 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 73 |
+
|
| 74 |
+
# Health check
|
| 75 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 76 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 77 |
+
|
| 78 |
+
# Run the FastAPI server
|
| 79 |
+
# The module path is constructed to work with the /app/env structure
|
| 80 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
server/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Workflow Arena environment server components."""
|
| 8 |
+
|
| 9 |
+
from workflow_arena.server.workflow_arena_environment import WorkflowArenaEnvironment
|
| 10 |
+
|
| 11 |
+
__all__ = ["WorkflowArenaEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
FastAPI application for the Workflow Arena Environment.
|
| 9 |
+
|
| 10 |
+
This module creates an HTTP server that exposes the WorkflowArenaEnvironment
|
| 11 |
+
over HTTP and WebSocket endpoints, compatible with EnvClient.
|
| 12 |
+
|
| 13 |
+
Endpoints:
|
| 14 |
+
- POST /reset: Reset the environment
|
| 15 |
+
- POST /step: Execute an action
|
| 16 |
+
- GET /state: Get current environment state
|
| 17 |
+
- GET /schema: Get action/observation schemas
|
| 18 |
+
- WS /ws: WebSocket endpoint for persistent sessions
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
# Development (with auto-reload):
|
| 22 |
+
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 23 |
+
|
| 24 |
+
# Production:
|
| 25 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
|
| 26 |
+
|
| 27 |
+
# Or run directly:
|
| 28 |
+
python -m server.app
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
import gradio as gr
|
| 32 |
+
|
| 33 |
+
# create_app is provided by openenv-core; fail fast with an actionable message
# when the dependency is missing instead of surfacing a raw ModuleNotFoundError.
try:
    from openenv.core.env_server.http_server import create_app
except Exception as e:  # pragma: no cover
    raise ImportError(
        "openenv is required for the web interface. Install dependencies with '\n uv sync\n'"
    ) from e

from workflow_arena.models import WorkflowArenaAction, WorkflowArenaObservation
from workflow_arena.server.ui import create_gradio_app
from workflow_arena.server.workflow_arena_environment import WorkflowArenaEnvironment


# Create the app with web interface and README integration
app = create_app(
    WorkflowArenaEnvironment,
    WorkflowArenaAction,
    WorkflowArenaObservation,
    env_name="workflow_arena",
    max_concurrent_envs=1,  # increase this number to allow more concurrent WebSocket sessions
)

# Mount Gradio UI at root — MUST be after all API routes to avoid catchall interference
_gradio_app = create_gradio_app()
app = gr.mount_gradio_app(app, _gradio_app, path="/")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def main(host: str = "0.0.0.0", port: int = 8000):
    """
    Entry point for direct execution via uv run or python -m.

    This function enables running the server without Docker:
        uv run --project . server
        uv run --project . server --port 8001
        python -m workflow_arena.server.app

    Args:
        host: Host address to bind to (default: "0.0.0.0")
        port: Port number to listen on (default: 8000)

    For production deployments, consider using uvicorn directly with
    multiple workers:
        uvicorn workflow_arena.server.app:app --workers 4
    """
    # Imported lazily so importing this module never hard-requires uvicorn.
    import uvicorn

    uvicorn.run(app, host=host, port=port)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run the WorkflowArena server.")
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to listen on (default: 8000)."
    )
    args = parser.parse_args()
    # main() already defaults its port to 8000, so pass the parsed value
    # unconditionally instead of branching on whether it equals the default.
    main(port=args.port)
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.2  # keep in sync with pyproject.toml
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
gradio>=5.0.0
|
| 5 |
+
plotly>=5.24.0
|
| 6 |
+
|
server/ui.py
ADDED
|
@@ -0,0 +1,1270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Interactive Gradio UI for WorkflowArena."""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import random
|
| 12 |
+
from types import SimpleNamespace
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
import gradio as gr
|
| 16 |
+
import plotly.graph_objects as go
|
| 17 |
+
|
| 18 |
+
from workflow_arena.models import DifficultyPreset, TaskStatus, WorkflowActionType, WorkflowArenaAction
|
| 19 |
+
from workflow_arena.presets import get_preset_config
|
| 20 |
+
from workflow_arena.server.workflow_arena_environment import WorkflowArenaEnvironment
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Per-UI-session mutable state. Keys (see `_blank_session`): "env" holds the
# WorkflowArenaEnvironment instance, "observation" the latest observation (or
# None before the first reset), and "history" a list of accumulated entries.
Session = dict[str, Any]
|
| 24 |
+
|
| 25 |
+
# Column headers for the task-detail table. Order must stay in sync with the
# per-task row lists built by `_detail_rows`.
DETAIL_HEADERS = [
    "Task",
    "Priority",
    "Duration",
    "Deadline",
    "Criticality",
    "Slack",
    "Deps",
    "Downstream",
    "Attempts",
    "Start",
    "End",
]
|
| 38 |
+
|
| 39 |
+
# Human-readable blurbs per difficulty preset, keyed by DifficultyPreset value.
# Consumed by `_preset_html`: "label" is the display name, "summary" describes
# the scenario, "focus" gives strategy advice, "mechanics" lists rule changes.
PRESET_BRIEFS = {
    DifficultyPreset.EASY.value: {
        "label": "Warm-up Flow",
        "summary": "Small DAG, softer deadlines, and fewer traps. Good for learning how dispatch and wait interact.",
        "focus": "Keep workers busy, avoid empty waits, and build intuition for parallel batches.",
        "mechanics": "No hard time budget and no failure events.",
    },
    DifficultyPreset.MEDIUM.value: {
        "label": "Balanced Pressure",
        "summary": "Tighter dependencies and more timing pressure. Scheduling mistakes start to compound.",
        "focus": "Balance urgency, downstream unlocks, and worker utilization.",
        "mechanics": "Adds a fixed time budget and terminal penalty for unfinished work.",
    },
    DifficultyPreset.HARD.value: {
        "label": "Critical Path Sprint",
        "summary": "Dense DAGs, tighter deadlines, and much less room for idle capacity.",
        "focus": "Protect the critical path and use every free slot intentionally.",
        "mechanics": "Adds a tighter time budget plus seeded worker outages and task retry failures.",
    },
}
|
| 59 |
+
|
| 60 |
+
CSS = """
|
| 61 |
+
.gradio-container {
|
| 62 |
+
background:
|
| 63 |
+
radial-gradient(circle at top left, rgba(216, 116, 76, 0.14), transparent 28%),
|
| 64 |
+
radial-gradient(circle at top right, rgba(201, 157, 92, 0.10), transparent 24%),
|
| 65 |
+
linear-gradient(180deg, #fbf4ea 0%, #f4e7d6 100%);
|
| 66 |
+
color: #2d241c;
|
| 67 |
+
font-family: "IBM Plex Sans", "Avenir Next", "Segoe UI", sans-serif;
|
| 68 |
+
}
|
| 69 |
+
.wa-shell {max-width: 1380px; margin: 0 auto; padding: 10px 10px 30px;}
|
| 70 |
+
.wa-title {margin-bottom: 18px;}
|
| 71 |
+
.wa-title h1 {margin: 0; font-size: 2.6rem; line-height: 1; letter-spacing: -0.04em; color: #3e2618;}
|
| 72 |
+
.wa-title p {margin: 10px 0 0; max-width: 920px; font-size: 1rem; color: #745f50;}
|
| 73 |
+
.wa-hero {display: grid; grid-template-columns: 1.15fr 0.85fr; gap: 18px; margin-bottom: 18px;}
|
| 74 |
+
.wa-card {
|
| 75 |
+
background: rgba(255, 252, 247, 0.92);
|
| 76 |
+
border: 1px solid rgba(139, 110, 84, 0.12);
|
| 77 |
+
border-radius: 24px;
|
| 78 |
+
box-shadow: 0 18px 60px rgba(114, 84, 51, 0.10);
|
| 79 |
+
backdrop-filter: blur(16px);
|
| 80 |
+
}
|
| 81 |
+
.wa-control-card {padding: 20px;}
|
| 82 |
+
.wa-control-card h3,
|
| 83 |
+
.wa-panel h3,
|
| 84 |
+
.wa-playbook h3 {
|
| 85 |
+
margin: 0 0 8px;
|
| 86 |
+
font-size: 0.8rem;
|
| 87 |
+
letter-spacing: 0.12em;
|
| 88 |
+
text-transform: uppercase;
|
| 89 |
+
color: #9f5b33;
|
| 90 |
+
}
|
| 91 |
+
.wa-control-card p,
|
| 92 |
+
.wa-panel p,
|
| 93 |
+
.wa-playbook p {margin: 0; color: #715e50; line-height: 1.5;}
|
| 94 |
+
.wa-control-grid {display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 12px; align-items: end; margin-top: 14px;}
|
| 95 |
+
.wa-control-buttons {display: flex; gap: 10px; align-items: center; margin-top: 12px;}
|
| 96 |
+
.wa-inline-buttons {display: flex; gap: 10px; align-items: center; flex-wrap: wrap;}
|
| 97 |
+
.wa-compact-accordion {margin-top: 14px;}
|
| 98 |
+
.wa-compact-accordion .label-wrap span {font-size: 0.85rem;}
|
| 99 |
+
.wa-problem-box {
|
| 100 |
+
padding: 4px 2px 2px;
|
| 101 |
+
border-radius: 18px;
|
| 102 |
+
}
|
| 103 |
+
.wa-problem-box strong {color: #8d4f2d;}
|
| 104 |
+
.wa-problem-box p {margin: 0 0 8px; color: #6d594c;}
|
| 105 |
+
.wa-problem-box p:last-child {margin-bottom: 0;}
|
| 106 |
+
.wa-preset-card {padding: 20px; min-height: 100%;}
|
| 107 |
+
.wa-preset-card .eyebrow {font-size: 0.75rem; letter-spacing: 0.14em; text-transform: uppercase; color: #b16d3d;}
|
| 108 |
+
.wa-preset-card .name {margin-top: 8px; font-size: 1.7rem; font-weight: 700; letter-spacing: -0.03em; color: #3e2618;}
|
| 109 |
+
.wa-preset-card .summary {margin-top: 10px; color: #6e594a; line-height: 1.55;}
|
| 110 |
+
.wa-preset-meta {margin-top: 12px; color: #7b6554; line-height: 1.5;}
|
| 111 |
+
.wa-preset-focus {
|
| 112 |
+
margin-top: 14px;
|
| 113 |
+
padding: 12px 14px;
|
| 114 |
+
border-radius: 18px;
|
| 115 |
+
background: linear-gradient(135deg, rgba(222, 143, 93, 0.12), rgba(242, 210, 171, 0.28));
|
| 116 |
+
border: 1px solid rgba(174, 117, 72, 0.14);
|
| 117 |
+
}
|
| 118 |
+
.wa-topbar {display: grid; grid-template-columns: 1.2fr 1fr 1fr 1fr 1fr 1fr; gap: 12px; margin: 0 0 16px;}
|
| 119 |
+
.wa-stat {
|
| 120 |
+
background: linear-gradient(180deg, #fff8ef 0%, #f9efe2 100%);
|
| 121 |
+
border: 1px solid rgba(168, 130, 95, 0.16);
|
| 122 |
+
border-radius: 20px;
|
| 123 |
+
padding: 16px 18px;
|
| 124 |
+
}
|
| 125 |
+
.wa-stat .label {font-size: 0.7rem; letter-spacing: 0.12em; text-transform: uppercase; color: #a56c43;}
|
| 126 |
+
.wa-stat .value {margin-top: 6px; font-size: 1.65rem; font-weight: 700; color: #3e2618;}
|
| 127 |
+
.wa-stat .sub {margin-top: 4px; font-size: 0.82rem; color: #7a6657;}
|
| 128 |
+
.wa-banner {
|
| 129 |
+
border-radius: 24px;
|
| 130 |
+
padding: 18px 20px;
|
| 131 |
+
border: 1px solid rgba(177, 107, 70, 0.18);
|
| 132 |
+
background: linear-gradient(135deg, rgba(245, 186, 144, 0.35), rgba(255, 251, 245, 0.96));
|
| 133 |
+
color: #35271f;
|
| 134 |
+
margin-bottom: 16px;
|
| 135 |
+
}
|
| 136 |
+
.wa-banner.invalid {
|
| 137 |
+
background: linear-gradient(135deg, rgba(247, 202, 196, 0.92), rgba(255, 246, 244, 0.98));
|
| 138 |
+
border-color: rgba(180, 84, 69, 0.22);
|
| 139 |
+
}
|
| 140 |
+
.wa-banner.done {
|
| 141 |
+
background: linear-gradient(135deg, rgba(219, 229, 195, 0.9), rgba(255, 251, 244, 0.98));
|
| 142 |
+
border-color: rgba(112, 139, 90, 0.22);
|
| 143 |
+
}
|
| 144 |
+
.wa-banner-top {display: flex; justify-content: space-between; gap: 16px; align-items: flex-start;}
|
| 145 |
+
.wa-banner .status {
|
| 146 |
+
display: inline-block;
|
| 147 |
+
padding: 5px 10px;
|
| 148 |
+
border-radius: 999px;
|
| 149 |
+
font-size: 0.72rem;
|
| 150 |
+
font-weight: 700;
|
| 151 |
+
letter-spacing: 0.08em;
|
| 152 |
+
text-transform: uppercase;
|
| 153 |
+
background: rgba(164, 97, 59, 0.12);
|
| 154 |
+
}
|
| 155 |
+
.wa-banner .meta {margin-top: 8px; font-size: 0.9rem; color: #7f6654;}
|
| 156 |
+
.wa-banner .note {margin-top: 12px; font-size: 1rem; line-height: 1.5; color: #49372c;}
|
| 157 |
+
.wa-banner-grid {display: grid; grid-template-columns: repeat(4, minmax(0, 1fr)); gap: 10px; margin-top: 14px;}
|
| 158 |
+
.wa-banner-metric {
|
| 159 |
+
padding: 12px 14px;
|
| 160 |
+
border-radius: 18px;
|
| 161 |
+
background: rgba(255, 253, 248, 0.76);
|
| 162 |
+
border: 1px solid rgba(178, 132, 96, 0.12);
|
| 163 |
+
}
|
| 164 |
+
.wa-banner-metric span {display: block; font-size: 0.72rem; letter-spacing: 0.08em; text-transform: uppercase; color: #a16b45;}
|
| 165 |
+
.wa-banner-metric strong {display: block; margin-top: 6px; font-size: 1.1rem; color: #3d281c;}
|
| 166 |
+
.wa-progress {height: 10px; margin-top: 14px; border-radius: 999px; overflow: hidden; background: rgba(120, 83, 54, 0.10);}
|
| 167 |
+
.wa-progress-fill {height: 100%; background: linear-gradient(90deg, #e49157 0%, #f1c27a 100%);}
|
| 168 |
+
.wa-main {display: grid; grid-template-columns: 1.18fr 0.82fr; gap: 18px; align-items: start;}
|
| 169 |
+
.wa-left-stack,
|
| 170 |
+
.wa-right-stack {display: grid; gap: 18px;}
|
| 171 |
+
.wa-panel {padding: 18px 18px 16px;}
|
| 172 |
+
.wa-playbook {padding: 18px;}
|
| 173 |
+
.wa-playbook-header {display: flex; justify-content: space-between; gap: 12px; align-items: center; margin-bottom: 12px;}
|
| 174 |
+
.wa-playbook-title {font-size: 1.4rem; font-weight: 700; letter-spacing: -0.03em; color: #3e2618;}
|
| 175 |
+
.wa-chip-row {display: flex; flex-wrap: wrap; gap: 8px; margin-top: 12px;}
|
| 176 |
+
.wa-chip {
|
| 177 |
+
display: inline-flex;
|
| 178 |
+
align-items: center;
|
| 179 |
+
padding: 7px 10px;
|
| 180 |
+
border-radius: 999px;
|
| 181 |
+
background: rgba(172, 121, 80, 0.10);
|
| 182 |
+
color: #7b4f32;
|
| 183 |
+
font-size: 0.84rem;
|
| 184 |
+
font-weight: 600;
|
| 185 |
+
}
|
| 186 |
+
.wa-lane-header {display: flex; justify-content: space-between; gap: 12px; align-items: flex-start; margin-bottom: 8px;}
|
| 187 |
+
.wa-lane-title {font-size: 1.32rem; font-weight: 700; letter-spacing: -0.03em; color: #3e2618;}
|
| 188 |
+
.wa-lane-copy {font-size: 0.96rem; color: #78624f;}
|
| 189 |
+
.wa-hint {
|
| 190 |
+
margin-bottom: 14px;
|
| 191 |
+
padding: 12px 14px;
|
| 192 |
+
border-radius: 18px;
|
| 193 |
+
background: rgba(174, 126, 88, 0.08);
|
| 194 |
+
border: 1px solid rgba(174, 126, 88, 0.12);
|
| 195 |
+
color: #6d4a32;
|
| 196 |
+
}
|
| 197 |
+
.wa-card-grid {display: grid; grid-template-columns: repeat(auto-fit, minmax(235px, 1fr)); gap: 12px;}
|
| 198 |
+
.wa-task-card {
|
| 199 |
+
background: linear-gradient(180deg, #fffaf2 0%, #f7ecde 100%);
|
| 200 |
+
border: 1px solid rgba(175, 135, 100, 0.18);
|
| 201 |
+
border-radius: 22px;
|
| 202 |
+
padding: 14px;
|
| 203 |
+
color: #34261d;
|
| 204 |
+
}
|
| 205 |
+
.wa-task-card.running {background: linear-gradient(180deg, #f4ebe0 0%, #ecdcca 100%);}
|
| 206 |
+
.wa-task-card.recommended {outline: 2px solid rgba(226, 145, 87, 0.9); outline-offset: 2px;}
|
| 207 |
+
.wa-task-head {display: flex; justify-content: space-between; gap: 10px; align-items: flex-start; margin-bottom: 10px;}
|
| 208 |
+
.wa-task-name {font-size: 1.08rem; font-weight: 700; color: #3c271b;}
|
| 209 |
+
.wa-badge {
|
| 210 |
+
display: inline-flex;
|
| 211 |
+
align-items: center;
|
| 212 |
+
padding: 4px 8px;
|
| 213 |
+
border-radius: 999px;
|
| 214 |
+
background: rgba(170, 123, 84, 0.12);
|
| 215 |
+
color: #825336;
|
| 216 |
+
font-size: 0.68rem;
|
| 217 |
+
font-weight: 700;
|
| 218 |
+
letter-spacing: 0.08em;
|
| 219 |
+
text-transform: uppercase;
|
| 220 |
+
}
|
| 221 |
+
.wa-badge.urgent {background: rgba(208, 108, 97, 0.16); color: #8c3c35;}
|
| 222 |
+
.wa-badge.active {background: rgba(151, 179, 120, 0.18); color: #5f7142;}
|
| 223 |
+
.wa-badge.recommended {background: rgba(229, 166, 93, 0.20); color: #86501f;}
|
| 224 |
+
.wa-badge.retry {background: rgba(176, 141, 78, 0.18); color: #7a5a22;}
|
| 225 |
+
.wa-task-meta {display: flex; flex-wrap: wrap; gap: 8px; margin-bottom: 10px;}
|
| 226 |
+
.wa-task-meta span {
|
| 227 |
+
padding: 5px 8px;
|
| 228 |
+
border-radius: 999px;
|
| 229 |
+
background: rgba(179, 139, 104, 0.10);
|
| 230 |
+
font-size: 0.76rem;
|
| 231 |
+
color: #77553b;
|
| 232 |
+
}
|
| 233 |
+
.wa-metrics {display: grid; grid-template-columns: repeat(2, minmax(0, 1fr)); gap: 10px 12px;}
|
| 234 |
+
.wa-metric span {display: block; font-size: 0.68rem; letter-spacing: 0.08em; text-transform: uppercase; color: #a26c45;}
|
| 235 |
+
.wa-metric strong {display: block; margin-top: 4px; font-size: 0.98rem; color: #35261d;}
|
| 236 |
+
.wa-empty {
|
| 237 |
+
padding: 20px;
|
| 238 |
+
border-radius: 20px;
|
| 239 |
+
border: 1px dashed rgba(176, 133, 98, 0.26);
|
| 240 |
+
background: rgba(178, 143, 112, 0.06);
|
| 241 |
+
color: #7b6656;
|
| 242 |
+
text-align: center;
|
| 243 |
+
}
|
| 244 |
+
.wa-action-row {display: flex; flex-wrap: wrap; gap: 10px; margin-top: 14px;}
|
| 245 |
+
.wa-button-primary button {background: linear-gradient(135deg, #d97b4b, #c95f34) !important; color: #fff7f0 !important; border: none !important;}
|
| 246 |
+
.wa-button-secondary button {background: #8f5b3b !important; color: #fff8f2 !important; border: none !important;}
|
| 247 |
+
.wa-button-ghost button {background: rgba(180, 132, 96, 0.08) !important; color: #7a4d31 !important; border: 1px solid rgba(180, 132, 96, 0.16) !important;}
|
| 248 |
+
.wa-plot-wrap {padding: 10px 10px 2px;}
|
| 249 |
+
.wa-footer-stack {display: grid; gap: 18px; margin-top: 18px;}
|
| 250 |
+
.wa-accordion {border-radius: 20px !important; overflow: hidden;}
|
| 251 |
+
@media (max-width: 1080px) {
|
| 252 |
+
.wa-hero,
|
| 253 |
+
.wa-main {grid-template-columns: 1fr;}
|
| 254 |
+
.wa-topbar {grid-template-columns: repeat(2, minmax(0, 1fr));}
|
| 255 |
+
}
|
| 256 |
+
@media (max-width: 760px) {
|
| 257 |
+
.wa-control-grid {grid-template-columns: 1fr;}
|
| 258 |
+
.wa-banner-grid {grid-template-columns: repeat(2, minmax(0, 1fr));}
|
| 259 |
+
.wa-topbar {grid-template-columns: 1fr;}
|
| 260 |
+
}
|
| 261 |
+
"""
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def _blank_session() -> Session:
    """Create a fresh UI session: new environment, no observation, empty history."""
    session: Session = {
        "env": WorkflowArenaEnvironment(),
        "observation": None,
        "history": [],
    }
    return session
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def _fmt_num(value: Any, digits: int = 3) -> str:
|
| 269 |
+
if value is None:
|
| 270 |
+
return "—"
|
| 271 |
+
if isinstance(value, float):
|
| 272 |
+
return f"{value:.{digits}f}"
|
| 273 |
+
return str(value)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def _preset_html(preset: str) -> str:
    """Render the difficulty-preset brief card for the selected preset value."""
    # Unknown preset strings fall back to the EASY brief.
    brief = PRESET_BRIEFS.get(preset, PRESET_BRIEFS[DifficultyPreset.EASY.value])
    config = get_preset_config(DifficultyPreset(preset))
    if config.time_budget_multiplier is None:
        budget_note = "No fixed time budget."
    else:
        budget_note = f"Time budget uses {config.time_budget_multiplier:.2f}x the lower-bound makespan."
    pieces = [
        '<div class="wa-preset-card wa-card">',
        '<div class="eyebrow">Preset brief</div>',
        f'<div class="name">{brief["label"]}</div>',
        f'<div class="summary">{brief["summary"]}</div>',
        f'<div class="wa-preset-meta">{budget_note} {brief["mechanics"]}</div>',
        f'<div class="wa-preset-focus"><strong>What matters now:</strong> {brief["focus"]}</div>',
        "</div>",
    ]
    return "".join(pieces)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _status_text(observation: Any) -> tuple[str, str]:
|
| 296 |
+
if observation.validation_error:
|
| 297 |
+
return "Invalid action", "bad"
|
| 298 |
+
if observation.done and observation.termination_reason:
|
| 299 |
+
return "Episode terminated", "bad"
|
| 300 |
+
if observation.done:
|
| 301 |
+
return "Workflow completed", "good"
|
| 302 |
+
if observation.free_workers == 0 and observation.running_tasks:
|
| 303 |
+
return "Wait required", ""
|
| 304 |
+
if observation.ready_tasks:
|
| 305 |
+
return "Ready to dispatch", ""
|
| 306 |
+
return "Waiting on completions", ""
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def _recommended_task_ids(observation: Any) -> list[str]:
|
| 310 |
+
if observation is None or observation.done or observation.free_workers <= 0:
|
| 311 |
+
return []
|
| 312 |
+
ready_tasks = list(observation.ready_tasks)
|
| 313 |
+
if not ready_tasks:
|
| 314 |
+
return []
|
| 315 |
+
time_remaining = observation.time_remaining
|
| 316 |
+
ranked = sorted(
|
| 317 |
+
ready_tasks,
|
| 318 |
+
key=lambda task: (
|
| 319 |
+
time_remaining is not None and task.duration > time_remaining,
|
| 320 |
+
max(0, task.duration - time_remaining) if time_remaining is not None else 0,
|
| 321 |
+
task.slack if task.slack is not None else 1_000_000,
|
| 322 |
+
task.deadline if task.deadline is not None else 1_000_000,
|
| 323 |
+
-(task.criticality or 0.0),
|
| 324 |
+
-task.priority,
|
| 325 |
+
task.duration,
|
| 326 |
+
task.task_id,
|
| 327 |
+
),
|
| 328 |
+
)
|
| 329 |
+
return [task.task_id for task in ranked[: observation.free_workers]]
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def _dispatch_window(observation: Any) -> tuple[int, int, int]:
|
| 333 |
+
ready_count = len(observation.ready_tasks)
|
| 334 |
+
free_workers = max(0, observation.free_workers)
|
| 335 |
+
dispatchable_now = min(ready_count, free_workers)
|
| 336 |
+
overflow_ready = max(0, ready_count - free_workers)
|
| 337 |
+
return ready_count, dispatchable_now, overflow_ready
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def _topbar_html(observation: Any) -> str:
    """Render the six-stat summary strip (state, workers, progress, reward,
    time, score) shown above the main panels."""
    done_count = observation.progress.completed
    total_count = max(1, observation.progress.total)
    usable = getattr(observation, "effective_workers", observation.total_workers)

    # Prefer the live benchmark score; fall back to the success-metrics copy.
    score = observation.benchmark_score
    if score is None:
        score = observation.success_metrics.benchmark_score

    if observation.time_remaining is not None:
        time_sub = f"{observation.time_remaining} remaining"
    else:
        time_sub = "simulation clock"
    # Only call out idle/usable when an outage reduces effective capacity.
    if usable != observation.total_workers:
        worker_sub = f"idle / usable of {observation.total_workers}"
    else:
        worker_sub = "free / total"

    cards = [
        ("State", _status_text(observation)[0], f"{observation.progress.ready} ready / {observation.progress.running} running"),
        ("Workers", f"{observation.free_workers}/{usable}", worker_sub),
        ("Completed", f"{done_count}/{total_count}", f"{round(100 * done_count / total_count, 1)}% finished"),
        ("Reward", _fmt_num(observation.cumulative_reward, 3), "cumulative"),
        ("Time", observation.current_time, time_sub),
        ("Score", _fmt_num(score, 3), "terminal if done"),
    ]
    stat_divs = "".join(
        f'<div class="wa-stat"><div class="label">{label}</div><div class="value">{value}</div><div class="sub">{sub}</div></div>'
        for label, value, sub in cards
    )
    return f'<div class="wa-topbar">{stat_divs}</div>'
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def _banner_html(observation: Any) -> str:
    """Render the large status banner: status pill, episode metadata, note,
    four quick metrics, and a completion progress bar.

    The banner gets an extra CSS class ("invalid"/"done") based on the
    severity tag reported by `_status_text`.
    """
    completed = observation.progress.completed
    total = max(1, observation.progress.total)  # guard divide-by-zero
    progress_pct = round(100 * completed / total, 1)
    status_text, status_kind = _status_text(observation)
    banner_class = "wa-banner"
    if status_kind == "bad":
        banner_class += " invalid"
    elif status_kind == "good":
        banner_class += " done"

    # Fold the validation error and any recent failure details into the note.
    failure_note = _failure_summary(observation)
    note = observation.note or "No environment note."
    if observation.validation_error:
        note = f"{note} {observation.validation_error}"
    if failure_note:
        note = f"{note} {failure_note}"

    # NOTE: the previous version also computed a `score` fallback here
    # (benchmark_score / success_metrics.benchmark_score) but never used it in
    # the rendered HTML; that dead code has been removed.

    metric_cards = [
        ("Ready", observation.progress.ready),
        ("Running", observation.progress.running),
        ("Workers", f"{getattr(observation, 'effective_workers', observation.total_workers)}/{observation.total_workers}"),
        (
            "Time Left",
            observation.time_remaining if observation.time_remaining is not None else "—",
        ),
    ]

    return (
        f'<div class="{banner_class}">'
        '<div class="wa-banner-top">'
        '<div>'
        f'<span class="status">{status_text}</span>'
        f'<div class="meta">Preset: {observation.config.preset.value} • Seed: {observation.config.seed} • Workers: {observation.total_workers}</div>'
        f'<div class="note">{note}</div>'
        "</div>"
        f'<div class="meta">{completed}/{total} complete</div>'
        "</div>"
        '<div class="wa-banner-grid">'
        + "".join(
            f'<div class="wa-banner-metric"><span>{label}</span><strong>{value}</strong></div>'
            for label, value in metric_cards
        )
        + "</div>"
        f'<div class="wa-progress"><div class="wa-progress-fill" style="width:{progress_pct:.1f}%"></div></div>'
        "</div>"
    )
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def _planner_html(observation: Any) -> str:
    """Render the "Decision support" playbook card: a headline suggestion,
    an explanatory body, and a row of context chips (capacity, rewards,
    timing, outages)."""
    recommended = _recommended_task_ids(observation)
    ready_count, dispatchable_now, overflow_ready = _dispatch_window(observation)
    # Pick the headline/body by precedence: terminal state, rejected action,
    # forced wait, a concrete dispatch recommendation, then idle hold.
    if observation.done:
        title = "Episode finished"
        body = "Reset for another episode or inspect the final timeline and reward trace below."
    elif observation.validation_error:
        title = "Fix the last move"
        body = observation.validation_error
    elif observation.free_workers == 0 and observation.running_tasks:
        title = "Advance time"
        body = "All workers are occupied. Waiting is the only legal move until the next task completes."
    elif recommended:
        title = f"Dispatch {', '.join(recommended)}"
        body = (
            "These tasks minimize slack first, then prefer tighter deadlines, stronger criticality, and higher priority. "
            f"The recommendation is capped at `{dispatchable_now}` because only `{observation.free_workers}` worker"
            f"{'s are' if observation.free_workers != 1 else ' is'} free right now."
        )
    else:
        title = "Hold for completions"
        body = "No ready work is available. Wait until dependencies unlock new tasks."

    # Always-present chips first; conditional chips appended below.
    chips = [
        f"free workers: {observation.free_workers}",
        f"usable workers: {getattr(observation, 'effective_workers', observation.total_workers)}",
        f"ready queue: {ready_count}",
        f"dispatchable now: {dispatchable_now}",
        f"last reward: {_fmt_num(observation.reward if hasattr(observation, 'reward') else 0.0, 3)}",
    ]
    if observation.time_remaining is not None:
        chips.append(f"time remaining: {observation.time_remaining}")
    if overflow_ready:
        chips.append(f"queued beyond capacity: {overflow_ready}")
    if observation.running_tasks:
        # Earliest projected finish among running tasks (falls back to now
        # for tasks without an end_time).
        next_finish = min(task.end_time or observation.current_time for task in observation.running_tasks)
        chips.append(f"next completion: t={next_finish}")
    if observation.degraded_workers:
        chips.append(f"worker outage: -{observation.degraded_workers} usable")

    return (
        '<div class="wa-playbook wa-card">'
        '<div class="wa-playbook-header">'
        '<div>'
        '<h3>Decision support</h3>'
        f'<div class="wa-playbook-title">{title}</div>'
        "</div>"
        f'<div class="wa-chip">{_status_text(observation)[0]}</div>'
        "</div>"
        f'<p>{body}</p>'
        '<div class="wa-chip-row">'
        + "".join(f'<div class="wa-chip">{chip}</div>' for chip in chips)
        + "</div>"
        "</div>"
    )
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def _capacity_hint(observation: Any) -> str:
    """One-sentence guidance about how many tasks can be dispatched now."""
    ready_count, dispatchable_now, overflow_ready = _dispatch_window(observation)
    if observation.done:
        return "Episode finished. Review the schedule or reset to try another seed."
    if observation.validation_error:
        plural = "s" if observation.free_workers != 1 else ""
        return (
            f"Last action was rejected. Select at most {observation.free_workers} ready "
            f"task{plural}."
        )
    if observation.free_workers == 0 and observation.running_tasks:
        return "All workers are busy. Use Wait to jump to the next completion."
    if not observation.ready_tasks:
        return "No ready tasks available right now. Wait until dependencies unlock more work."
    overflow_suffix = ""
    if overflow_ready:
        overflow_suffix = f" `{overflow_ready}` ready task(s) will stay queued."
    ready_plural = "s" if ready_count != 1 else ""
    return (
        f"{ready_count} ready task{ready_plural}. "
        f"You can dispatch up to {dispatchable_now} right now.{overflow_suffix}"
    )
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
def _task_badges(task: Any, *, running: bool = False, recommended: bool = False) -> str:
|
| 505 |
+
badges: list[str] = []
|
| 506 |
+
if task.deadline is not None and task.slack is not None and task.slack <= 1:
|
| 507 |
+
badges.append('<span class="wa-badge urgent">Urgent</span>')
|
| 508 |
+
if getattr(task, "attempt_count", 0) > 0:
|
| 509 |
+
badges.append(
|
| 510 |
+
f'<span class="wa-badge retry">Retry {getattr(task, "attempt_count", 0) + 1}</span>'
|
| 511 |
+
)
|
| 512 |
+
if recommended:
|
| 513 |
+
badges.append('<span class="wa-badge recommended">Recommended</span>')
|
| 514 |
+
if running:
|
| 515 |
+
badges.append('<span class="wa-badge active">Running</span>')
|
| 516 |
+
if not badges:
|
| 517 |
+
badges.append('<span class="wa-badge">Ready</span>')
|
| 518 |
+
return "".join(badges)
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def _task_card(task: Any, *, running: bool = False, recommended: bool = False) -> str:
    """Render one task as an HTML card for the ready/running lanes."""
    dep_text = ", ".join(task.dependencies) if task.dependencies else "None"
    deadline_text = "—" if task.deadline is None else task.deadline
    start_text = "—" if task.start_time is None else task.start_time
    end_text = "—" if task.end_time is None else task.end_time
    css_classes = "wa-task-card"
    if running:
        css_classes += " running"
    if recommended:
        css_classes += " recommended"
    attempts_text = getattr(task, "attempt_count", 0) + 1
    # Running cards show the projected finish; ready cards show the start slot.
    when_label = "Finish" if running else "Start"
    when_value = end_text if running else start_text
    parts = [
        f'<div class="{css_classes}">',
        '<div class="wa-task-head">',
        f'<div class="wa-task-name">{task.task_id}</div>',
        f'<div>{_task_badges(task, running=running, recommended=recommended)}</div>',
        "</div>",
        f'<div class="wa-task-meta"><span>deps: {dep_text}</span><span>downstream: {task.downstream_count}</span><span>attempts: {attempts_text}</span></div>',
        '<div class="wa-metrics">',
        f'<div class="wa-metric"><span>Deadline</span><strong>{deadline_text}</strong></div>',
        f'<div class="wa-metric"><span>Duration</span><strong>{task.duration}</strong></div>',
        f'<div class="wa-metric"><span>Priority</span><strong>{task.priority}</strong></div>',
        f'<div class="wa-metric"><span>Criticality</span><strong>{_fmt_num(task.criticality, 3)}</strong></div>',
        f'<div class="wa-metric"><span>Slack</span><strong>{_fmt_num(task.slack, 1)}</strong></div>',
        f'<div class="wa-metric"><span>{when_label}</span><strong>{when_value}</strong></div>',
        "</div>",
        "</div>",
    ]
    return "".join(parts)
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def _cards_html(tasks: list[Any], *, running: bool = False, recommended_ids: set[str] | None = None) -> str:
    """Render a lane of task cards, or an empty-state placeholder."""
    if not tasks:
        empty_msg = "No tasks in this lane yet." if running else "No ready tasks available."
        return f'<div class="wa-empty">{empty_msg}</div>'
    highlight = recommended_ids or set()
    cards = [
        _task_card(task, running=running, recommended=task.task_id in highlight)
        for task in tasks
    ]
    return '<div class="wa-card-grid">' + "".join(cards) + "</div>"
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
def _timeline_figure(observation: Any) -> go.Figure:
    """Build a Gantt-style Plotly figure of the episode.

    Completed and running tasks are horizontal bars (base = start time,
    length = span); ready tasks appear as diamond markers at the current
    time; a dashed vertical line marks "Now".
    """
    fig = go.Figure()

    # Stable alphabetical ordering per lane keeps the y-axis deterministic.
    completed = sorted(observation.completed_tasks, key=lambda task: task.task_id)
    running = sorted(observation.running_tasks, key=lambda task: task.task_id)
    ready = sorted(observation.ready_tasks, key=lambda task: task.task_id)

    timeline_tasks = completed + running
    task_ids = [task.task_id for task in timeline_tasks] + [task.task_id for task in ready]

    if completed:
        fig.add_trace(
            go.Bar(
                # Bar length is the task span; missing times fall back to 0.
                x=[max(0, (task.end_time or 0) - (task.start_time or 0)) for task in completed],
                y=[task.task_id for task in completed],
                base=[task.start_time or 0 for task in completed],
                orientation="h",
                name="Completed",
                marker_color="#85c88a",
                hovertemplate=(
                    "<b>%{y}</b><br>Status: Completed<br>Start: %{base}<br>"
                    "Duration: %{x}<extra></extra>"
                ),
            )
        )

    if running:
        fig.add_trace(
            go.Bar(
                # Missing start/end fall back to the current time, so an
                # in-flight task without timing data renders as zero-length.
                x=[
                    max(
                        0,
                        (task.end_time or observation.current_time) - (task.start_time or observation.current_time),
                    )
                    for task in running
                ],
                y=[task.task_id for task in running],
                base=[task.start_time or observation.current_time for task in running],
                orientation="h",
                name="Running",
                marker_color="#d88a5b",
                hovertemplate=(
                    "<b>%{y}</b><br>Status: Running<br>Start: %{base}<br>"
                    "Allocated span: %{x}<extra></extra>"
                ),
            )
        )

    if ready:
        fig.add_trace(
            go.Scatter(
                x=[observation.current_time] * len(ready),
                y=[task.task_id for task in ready],
                mode="markers",
                name="Ready",
                marker=dict(color="#9e6a43", size=11, symbol="diamond"),
                # customdata rows: [deadline, priority, duration] per task,
                # referenced by index in the hovertemplate below.
                customdata=[[task.deadline, task.priority, task.duration] for task in ready],
                hovertemplate=(
                    "<b>%{y}</b><br>Status: Ready<br>Current time: %{x}<br>"
                    "Deadline: %{customdata[0]}<br>Priority: %{customdata[1]}<br>"
                    "Duration: %{customdata[2]}<extra></extra>"
                ),
            )
        )

    if not task_ids:
        # Empty episode: show a placeholder row and a call-to-action note.
        task_ids = ["No tasks yet"]
        fig.add_annotation(
            text="Reset an episode to populate the workflow timeline.",
            x=0.5,
            y=0.5,
            xref="paper",
            yref="paper",
            showarrow=False,
            font=dict(color="#8c6f58", size=14),
        )

    # X-axis horizon: latest known end time (or now), padded by one unit.
    horizon_candidates = [observation.current_time + 1]
    horizon_candidates.extend(task.end_time or 0 for task in completed)
    horizon_candidates.extend(task.end_time or observation.current_time for task in running)
    x_max = max(horizon_candidates) + 1

    fig.add_vline(
        x=observation.current_time,
        line_width=2,
        line_dash="dash",
        line_color="#9f6b48",
        annotation_text="Now",
        annotation_position="top left",
    )

    fig.update_layout(
        barmode="overlay",
        # Height grows with the number of task rows, floored at 280px.
        height=max(280, 90 + 34 * len(task_ids)),
        margin=dict(l=10, r=10, t=44, b=18),
        paper_bgcolor="#ffffff",
        plot_bgcolor="#ffffff",
        font=dict(color="#4d382b", family="IBM Plex Sans, Arial, sans-serif"),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, x=0),
        title=dict(text="Workflow Timeline", x=0.02, font=dict(size=18)),
        xaxis=dict(
            title="Simulated Time",
            range=[0, x_max],
            gridcolor="#eadfce",
            zeroline=False,
            linecolor="#cfb79c",
            title_font=dict(color="#6c4a33"),
            tickfont=dict(color="#6c4a33"),
        ),
        yaxis=dict(
            title="Tasks",
            categoryorder="array",
            # Reversed so the first task appears at the top of the chart.
            categoryarray=list(reversed(task_ids)),
            gridcolor="#f2e8da",
            linecolor="#cfb79c",
            title_font=dict(color="#6c4a33"),
            tickfont=dict(color="#6c4a33"),
        ),
    )
    return fig
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def _detail_rows(tasks: list[Any]) -> list[list[Any]]:
    """Build one table row per task for the detail Dataframe views."""
    rows: list[list[Any]] = []
    for task in tasks:
        rows.append(
            [
                task.task_id,
                task.priority,
                task.duration,
                "—" if task.deadline is None else task.deadline,
                _fmt_num(task.criticality, 3),
                _fmt_num(task.slack, 1),
                len(task.dependencies),
                task.downstream_count,
                # attempt_count is 0-based; show a human-friendly 1-based count.
                getattr(task, "attempt_count", 0) + 1,
                "—" if task.start_time is None else task.start_time,
                "—" if task.end_time is None else task.end_time,
            ]
        )
    return rows
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
def _failure_summary(observation: Any) -> str:
|
| 703 |
+
events = getattr(observation, "recent_failure_events", []) or []
|
| 704 |
+
if not events:
|
| 705 |
+
return ""
|
| 706 |
+
return " ".join(event.detail for event in events if getattr(event, "detail", ""))
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
def _selection_markdown(selected_task_ids: list[str], observation: Any) -> str:
    """Summarize the current dispatch-batch selection as markdown.

    Covers three states: no episode yet, no tasks selected (shows the
    recommended batch when one exists), and a populated selection (shows
    priority sum, finish-time bounds, and a capacity warning if needed).
    """
    if observation is None:
        return "No episode yet. Reset an episode to start building a dispatch batch."

    # Only ids still present in the ready queue count toward the selection.
    task_map = {task.task_id: task for task in observation.ready_tasks}
    selected_tasks = [task_map[task_id] for task_id in selected_task_ids if task_id in task_map]
    capacity = max(0, observation.free_workers)
    ready_count, dispatchable_now, overflow_ready = _dispatch_window(observation)

    if not selected_tasks:
        recommended = _recommended_task_ids(observation)
        if not recommended:
            return (
                f"**Dispatch builder**\n\nReady queue: `{ready_count}`. "
                f"Dispatchable now: `{dispatchable_now}`."
            )
        overflow_suffix = f" `{overflow_ready}` ready task(s) stay queued after dispatch." if overflow_ready else ""
        return (
            f"**Dispatch builder**\n\nNo tasks selected yet. Ready queue: `{ready_count}`. "
            f"Recommended batch: `{', '.join(recommended)}`. Dispatch cap now: `{dispatchable_now}`.{overflow_suffix}"
        )

    total_priority = sum(task.priority for task in selected_tasks)
    # Finish-time bounds assume every selected task starts at current_time.
    shortest_finish = observation.current_time + min(task.duration for task in selected_tasks)
    longest_finish = observation.current_time + max(task.duration for task in selected_tasks)
    warnings: list[str] = []
    if len(selected_tasks) > capacity:
        warnings.append(f"Selection exceeds capacity by `{len(selected_tasks) - capacity}`.")

    warning_text = "\n\n" + " ".join(warnings) if warnings else ""
    return (
        f"**Dispatch builder**\n\nSelected `{len(selected_tasks)}` task(s) for `{capacity}` free slot(s). "
        f"Priority sum: `{total_priority}`. Earliest completion: `t={shortest_finish}`. "
        f"Longest in-flight span: `t={longest_finish}`.{warning_text}"
    )
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
def _reward_markdown(observation: Any) -> str:
|
| 747 |
+
breakdown = observation.last_reward_breakdown
|
| 748 |
+
rows = [
|
| 749 |
+
("completion", breakdown.completion_reward),
|
| 750 |
+
("utilization", breakdown.utilization_reward),
|
| 751 |
+
("deadline", breakdown.deadline_reward),
|
| 752 |
+
("criticality", breakdown.criticality_reward),
|
| 753 |
+
("idle", breakdown.idle_penalty),
|
| 754 |
+
("invalid", breakdown.invalid_action_penalty),
|
| 755 |
+
("terminal", breakdown.terminal_makespan_score),
|
| 756 |
+
("unfinished", breakdown.unfinished_task_penalty),
|
| 757 |
+
]
|
| 758 |
+
lines = ["| Channel | Value |", "| --- | ---: |"]
|
| 759 |
+
lines.extend(f"| {label} | {value:.3f} |" for label, value in rows)
|
| 760 |
+
return "\n".join(lines)
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
def _history_markdown(history: list[dict[str, Any]]) -> str:
    """Render the most recent actions, newest first, capped at 12 entries."""
    if not history:
        return "No actions yet."
    entries: list[str] = []
    for item in history[-12:]:
        reward = _fmt_num(item.get("reward"), 3)
        suffix = f" • error: `{item['error']}`" if item.get("error") else ""
        note = item.get("note") or ""
        entries.append(
            f"**{item['label']}** at `t={item['time']}` • reward `{reward}`{suffix} \n{note}"
        )
    # Latest action first.
    return "\n\n".join(reversed(entries))
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
def _blank_observation_view() -> Any:
    """Build a placeholder observation used before any episode has been reset."""
    empty_progress = SimpleNamespace(completed=0, ready=0, running=0, blocked=0, total=1)
    blank_metrics = SimpleNamespace(benchmark_score=None, unfinished_task_count=0)
    blank_config = SimpleNamespace(
        preset=SimpleNamespace(value=DifficultyPreset.EASY.value),
        seed=0,
    )
    return SimpleNamespace(
        reward=0.0,
        progress=empty_progress,
        benchmark_score=None,
        success_metrics=blank_metrics,
        free_workers=0,
        effective_workers=0,
        degraded_workers=0,
        total_workers=0,
        time_budget=None,
        time_remaining=None,
        cumulative_reward=0.0,
        current_time=0,
        done=False,
        termination_reason=None,
        validation_error=None,
        completed_tasks=[],
        ready_tasks=[],
        running_tasks=[],
        blocked_tasks=[],
        recent_failure_events=[],
        note="Reset an episode to start scheduling.",
        config=blank_config,
    )
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
def _empty_updates(session: Session):
    """Return the full output tuple for the no-episode state.

    The element order must match the `outputs` list wired in
    create_gradio_app; changing either side alone breaks the UI.
    """
    empty_rows: list[list[Any]] = []
    blank = _blank_observation_view()
    return (
        session,
        _preset_html(DifficultyPreset.EASY.value),
        _topbar_html(blank),
        _banner_html(blank),
        _planner_html(blank),
        "No episode yet.",  # decision hint
        _selection_markdown([], blank),
        '<div class="wa-empty">Reset an episode to see ready tasks.</div>',
        gr.update(choices=[], value=[]),  # ready selector cleared
        gr.update(interactive=False),  # dispatch
        gr.update(interactive=False),  # wait
        gr.update(interactive=False),  # select recommended
        gr.update(interactive=False),  # dispatch recommended
        gr.update(interactive=False),  # clear selection
        '<div class="wa-empty">Running tasks will appear here after dispatch.</div>',
        _timeline_figure(blank),
        "| Channel | Value |\n| --- | ---: |\n| — | — |",
        "No actions yet.",
        empty_rows,  # selected table
        empty_rows,  # completed table
        empty_rows,  # blocked table
    )
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
def _render(session: Session):
    """Translate the session's observation into the full UI output tuple.

    The element order must match the `outputs` list wired in
    create_gradio_app. Falls back to the no-episode view when the session
    has no observation yet.
    """
    observation = session.get("observation")
    env = session.get("env")
    history = session.get("history", [])
    if observation is None:
        return _empty_updates(session)

    # Prefer the environment's debug task views (richer detail) when the env
    # object is available; otherwise use the observation's own task lists.
    if env is None:
        completed_rows = _detail_rows(observation.completed_tasks)
        blocked_rows = _detail_rows(observation.blocked_tasks)
    else:
        completed_rows = _detail_rows(env.debug_task_views_for_status(TaskStatus.COMPLETED))
        blocked_rows = _detail_rows(env.debug_task_views_for_status(TaskStatus.BLOCKED))

    ready_choices = [task.task_id for task in observation.ready_tasks]
    recommended_ids = _recommended_task_ids(observation)
    can_recommend = bool(recommended_ids) and not observation.done
    can_wait = bool(observation.running_tasks) and not observation.done
    can_clear = bool(ready_choices) and not observation.done

    return (
        session,
        _preset_html(observation.config.preset.value),
        _topbar_html(observation),
        _banner_html(observation),
        _planner_html(observation),
        _capacity_hint(observation),
        _selection_markdown([], observation),
        _cards_html(observation.ready_tasks, running=False, recommended_ids=set(recommended_ids)),
        gr.update(choices=ready_choices, value=[]),
        # Dispatch stays disabled after a re-render; it is re-enabled only by
        # a subsequent selection change (_update_selection).
        gr.update(interactive=False),
        gr.update(interactive=can_wait),
        gr.update(interactive=can_recommend),
        gr.update(interactive=can_recommend),
        gr.update(interactive=can_clear),
        _cards_html(observation.running_tasks, running=True),
        _timeline_figure(observation),
        _reward_markdown(observation),
        _history_markdown(history),
        # Walrus keeps this a single expression; _detail_rows([]) is always [].
        empty_rows := _detail_rows([]),
        completed_rows,
        blocked_rows,
    )
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
def _append_history(
    session: Session,
    label: str,
    observation: Any,
    *,
    reward: float | None = None,
    error: str | None = None,
) -> Session:
    """Append one action record to the session history.

    The history list is copied before appending so earlier references to the
    old list are not mutated.
    """
    entry = {
        "label": label,
        "time": observation.current_time,
        "reward": reward,
        "error": error,
        "note": observation.note,
    }
    session["history"] = [*session.get("history", []), entry]
    return session
|
| 895 |
+
|
| 896 |
+
|
| 897 |
+
def _reset(preset: str, seed: float, worker_count: float, session: Session):
    """Reuse or create the environment, reset a fresh episode, and re-render."""
    env = session.get("env") or WorkflowArenaEnvironment()
    # Gradio number inputs arrive as floats; the environment expects ints.
    observation = env.reset(
        preset=preset,
        seed=int(seed),
        worker_count=int(worker_count),
    )
    fresh_session = {"env": env, "observation": observation, "history": []}
    fresh_session = _append_history(
        fresh_session,
        f"reset • preset `{preset}` • workers `{int(worker_count)}`",
        observation,
        reward=0.0,
        error=None,
    )
    return _render(fresh_session)
|
| 913 |
+
|
| 914 |
+
|
| 915 |
+
def _dispatch(selected_task_ids: list[str], session: Session):
    """Step the environment with a DISPATCH action for the selected tasks."""
    env = session.get("env")
    observation = session.get("observation")
    if env is None or observation is None:
        return _render(_blank_session())

    next_observation = env.step(
        WorkflowArenaAction(
            action_type=WorkflowActionType.DISPATCH,
            task_ids=selected_task_ids,
        )
    )
    next_session = {
        "env": env,
        "observation": next_observation,
        "history": session.get("history", []),
    }
    batch = ", ".join(selected_task_ids) if selected_task_ids else "(none)"
    next_session = _append_history(
        next_session,
        "dispatch " + batch,
        next_observation,
        reward=next_observation.reward,
        error=next_observation.validation_error,
    )
    return _render(next_session)
|
| 940 |
+
|
| 941 |
+
|
| 942 |
+
def _dispatch_recommended(session: Session):
    """Dispatch the heuristic-recommended batch in a single click."""
    observation = session.get("observation")
    if observation is None:
        return _render(_blank_session())
    recommended = _recommended_task_ids(observation)
    return _dispatch(recommended, session)
|
| 947 |
+
|
| 948 |
+
|
| 949 |
+
def _wait(session: Session):
    """Step the environment with a WAIT action to advance simulated time."""
    env = session.get("env")
    observation = session.get("observation")
    if env is None or observation is None:
        return _render(_blank_session())

    next_observation = env.step(
        WorkflowArenaAction(action_type=WorkflowActionType.WAIT, task_ids=[])
    )
    next_session = {
        "env": env,
        "observation": next_observation,
        "history": session.get("history", []),
    }
    next_session = _append_history(
        next_session,
        "wait",
        next_observation,
        reward=next_observation.reward,
        error=next_observation.validation_error,
    )
    return _render(next_session)
|
| 973 |
+
|
| 974 |
+
|
| 975 |
+
def _update_selection(selected_task_ids: list[str], session: Session):
    """Refresh the batch summary and preview; gate the dispatch button.

    Dispatch is legal only when every selected id is still ready, the batch
    fits in free worker slots, and the episode is not done.
    """
    observation = session.get("observation")
    if observation is None:
        return "No episode yet.", [], gr.update(interactive=False)

    task_map = {task.task_id: task for task in observation.ready_tasks}
    selected_tasks = [task_map[task_id] for task_id in selected_task_ids if task_id in task_map]

    can_dispatch = False
    if selected_tasks and not observation.done:
        # Chained comparison: all selected ids resolved AND batch fits capacity.
        if len(selected_tasks) == len(selected_task_ids) <= observation.free_workers:
            can_dispatch = True

    summary = _selection_markdown(selected_task_ids, observation)
    return summary, _detail_rows(selected_tasks), gr.update(interactive=can_dispatch)
|
| 993 |
+
|
| 994 |
+
|
| 995 |
+
def _select_recommended(session: Session):
    """Populate the checkbox group with the recommended dispatch batch."""
    observation = session.get("observation")
    if observation is None:
        return gr.update(value=[]), "No episode yet.", [], gr.update(interactive=False)

    recommended_ids = _recommended_task_ids(observation)
    task_map = {task.task_id: task for task in observation.ready_tasks}
    chosen = [task_map[task_id] for task_id in recommended_ids if task_id in task_map]
    dispatch_enabled = bool(recommended_ids) and not observation.done
    return (
        gr.update(value=recommended_ids),
        _selection_markdown(recommended_ids, observation),
        _detail_rows(chosen),
        gr.update(interactive=dispatch_enabled),
    )
|
| 1008 |
+
|
| 1009 |
+
|
| 1010 |
+
def _clear_selection(session: Session):
    """Empty the checkbox group and reset the batch preview and summary."""
    observation = session.get("observation")
    cleared_summary = _selection_markdown([], observation)
    return gr.update(value=[]), cleared_summary, [], gr.update(interactive=False)
|
| 1018 |
+
|
| 1019 |
+
|
| 1020 |
+
def _random_seed() -> int:
|
| 1021 |
+
return random.randint(0, 999_999)
|
| 1022 |
+
|
| 1023 |
+
|
| 1024 |
+
def _preset_controls_update(preset: str):
    """Sync the preset brief and default worker count when the preset changes."""
    config = get_preset_config(DifficultyPreset(preset))
    brief = _preset_html(preset)
    return brief, gr.update(value=config.worker_count)
|
| 1027 |
+
|
| 1028 |
+
|
| 1029 |
+
def create_gradio_app() -> gr.Blocks:
    """Assemble the WorkflowArena Gradio Blocks UI and wire its event handlers.

    Layout: episode controls (preset/seed/workers) at the top, a dispatch lane
    with batch builder on the left, mission control + timeline + reward/history
    accordions on the right, and completed/blocked task tables in the footer.
    Every stepping handler returns the same 21-element tuple matching
    `outputs` below.
    """
    with gr.Blocks(title="WorkflowArena") as demo:
        # Per-browser-session state dict: {"env", "observation", "history"}.
        session = gr.State(_blank_session())

        with gr.Column(elem_classes=["wa-shell"]):
            gr.HTML(f"<style>{CSS}</style>")
            gr.HTML(
                """
                <div class="wa-title">
                    <h1>WorkflowArena</h1>
                    <p>Run a workflow episode like a control room instead of a raw form. Reset a seeded DAG, inspect urgency and capacity, build a legal dispatch batch, then advance time when workers are saturated.</p>
                </div>
                """
            )

            with gr.Row(elem_classes=["wa-hero"]):
                with gr.Column(elem_classes=["wa-control-card", "wa-card"]):
                    gr.HTML("<h3>Episode controls</h3><p>Change the preset, seed, or worker count, then reset to generate a new scheduling problem.</p>")
                    with gr.Accordion("Problem Framing", open=False, elem_classes=["wa-accordion", "wa-compact-accordion"]):
                        gr.HTML(
                            """
                            <div class="wa-problem-box">
                                <p><strong>What this problem is:</strong> You are scheduling a workflow where tasks depend on each other and workers are limited. At every step, the legal move is either to dispatch ready tasks to free workers or wait for the next completion event.</p>
                                <p><strong>What good play looks like:</strong> Finish urgent and high-value work on time, keep workers utilized, and avoid delaying the critical path. Higher difficulties add a time budget and, in hard mode, failure events that reduce usable capacity or force retries.</p>
                            </div>
                            """
                        )
                    with gr.Row(elem_classes=["wa-control-grid"]):
                        preset = gr.Dropdown(
                            label="Preset",
                            choices=[preset.value for preset in DifficultyPreset],
                            value=DifficultyPreset.EASY.value,
                            interactive=True,
                        )
                        seed = gr.Number(label="Seed", value=0, precision=0, minimum=0)
                        workers = gr.Slider(
                            minimum=1,
                            maximum=6,
                            step=1,
                            value=3,
                            label="Workers",
                            interactive=True,
                        )
                    with gr.Row(elem_classes=["wa-control-buttons", "wa-inline-buttons"]):
                        reset_button = gr.Button(
                            "Reset Episode",
                            variant="primary",
                            elem_classes=["wa-button-primary"],
                        )
                        random_seed_button = gr.Button(
                            "Random Seed",
                            variant="secondary",
                            elem_classes=["wa-button-ghost"],
                        )

                preset_brief = gr.HTML(value=_preset_html(DifficultyPreset.EASY.value))

            topbar = gr.HTML()
            banner = gr.HTML()
            planner = gr.HTML()

            with gr.Row(elem_classes=["wa-main"]):
                with gr.Column(elem_classes=["wa-left-stack"]):
                    with gr.Column(elem_classes=["wa-panel", "wa-card"]):
                        gr.HTML(
                            """
                            <div class="wa-lane-header">
                                <div>
                                    <div class="wa-lane-title">Dispatch Lane</div>
                                    <div class="wa-lane-copy">Inspect ready tasks, build a batch, and send work only when the action is legal.</div>
                                </div>
                            </div>
                            """
                        )
                        selection_summary = gr.Markdown(elem_classes=["wa-hint"])
                        ready_cards = gr.HTML()
                        ready_selector = gr.CheckboxGroup(
                            label="Build dispatch batch",
                            info="Choose up to the number of currently free workers.",
                        )
                        with gr.Row(elem_classes=["wa-action-row"]):
                            select_recommended_button = gr.Button(
                                "Select Recommended",
                                variant="secondary",
                                interactive=False,
                                elem_classes=["wa-button-secondary"],
                            )
                            dispatch_recommended_button = gr.Button(
                                "Dispatch Recommended",
                                variant="primary",
                                interactive=False,
                                elem_classes=["wa-button-primary"],
                            )
                            clear_selection_button = gr.Button(
                                "Clear Selection",
                                variant="secondary",
                                interactive=False,
                                elem_classes=["wa-button-ghost"],
                            )
                            dispatch_button = gr.Button(
                                "Dispatch Selected",
                                variant="primary",
                                interactive=False,
                                elem_classes=["wa-button-primary"],
                            )
                            wait_button = gr.Button(
                                "Wait",
                                variant="secondary",
                                interactive=False,
                                elem_classes=["wa-button-secondary"],
                            )

                    with gr.Column(elem_classes=["wa-panel", "wa-card"]):
                        gr.HTML(
                            """
                            <div class="wa-lane-header">
                                <div>
                                    <div class="wa-lane-title">Selected Batch</div>
                                    <div class="wa-lane-copy">This preview updates as you pick tasks from the current ready queue.</div>
                                </div>
                            </div>
                            """
                        )
                        selected_table = gr.Dataframe(
                            headers=DETAIL_HEADERS,
                            value=[],
                            interactive=False,
                            wrap=True,
                            label="Dispatch preview",
                        )

                with gr.Column(elem_classes=["wa-right-stack"]):
                    with gr.Column(elem_classes=["wa-panel", "wa-card"]):
                        gr.HTML(
                            """
                            <div class="wa-lane-header">
                                <div>
                                    <div class="wa-lane-title">Mission Control</div>
                                    <div class="wa-lane-copy">Use the recommendation as a guide, not a rule. The reward trace below tells you whether the choice paid off.</div>
                                </div>
                            </div>
                            """
                        )
                        gr.Markdown(elem_classes=["wa-hint"], value="No episode yet.")
                        decision_hint = gr.Markdown(elem_classes=["wa-hint"])
                        running_cards = gr.HTML()

                    with gr.Column(elem_classes=["wa-plot-wrap", "wa-card"]):
                        timeline_plot = gr.Plot(label="Workflow Timeline")

                    with gr.Accordion("Reward Breakdown", open=False, elem_classes=["wa-accordion"]):
                        reward_markdown = gr.Markdown()

                    with gr.Accordion("Action History", open=False, elem_classes=["wa-accordion"]):
                        history_markdown = gr.Markdown()

            with gr.Row(elem_classes=["wa-footer-stack"]):
                with gr.Accordion("Completed Tasks", open=False, elem_classes=["wa-accordion"]):
                    completed_table = gr.Dataframe(
                        headers=DETAIL_HEADERS,
                        value=[],
                        interactive=False,
                        wrap=True,
                        label="Completed",
                    )

                with gr.Accordion("Blocked Tasks", open=False, elem_classes=["wa-accordion"]):
                    blocked_table = gr.Dataframe(
                        headers=DETAIL_HEADERS,
                        value=[],
                        interactive=False,
                        wrap=True,
                        label="Blocked",
                    )

        # Shared output list: every stepping handler (_reset, _dispatch, _wait,
        # _dispatch_recommended via _render / _empty_updates) returns a tuple
        # in exactly this order.
        outputs = [
            session,
            preset_brief,
            topbar,
            banner,
            planner,
            decision_hint,
            selection_summary,
            ready_cards,
            ready_selector,
            dispatch_button,
            wait_button,
            select_recommended_button,
            dispatch_recommended_button,
            clear_selection_button,
            running_cards,
            timeline_plot,
            reward_markdown,
            history_markdown,
            selected_table,
            completed_table,
            blocked_table,
        ]

        random_seed_button.click(_random_seed, outputs=[seed])
        preset.change(_preset_controls_update, inputs=[preset], outputs=[preset_brief, workers])

        reset_button.click(
            _reset,
            inputs=[preset, seed, workers, session],
            outputs=outputs,
        )
        dispatch_button.click(
            _dispatch,
            inputs=[ready_selector, session],
            outputs=outputs,
        )
        dispatch_recommended_button.click(
            _dispatch_recommended,
            inputs=[session],
            outputs=outputs,
        )
        wait_button.click(
            _wait,
            inputs=[session],
            outputs=outputs,
        )

        # Selection-only handlers touch just the builder widgets, not the
        # whole output tuple.
        ready_selector.change(
            _update_selection,
            inputs=[ready_selector, session],
            outputs=[selection_summary, selected_table, dispatch_button],
        )
        select_recommended_button.click(
            _select_recommended,
            inputs=[session],
            outputs=[ready_selector, selection_summary, selected_table, dispatch_button],
        )
        clear_selection_button.click(
            _clear_selection,
            inputs=[session],
            outputs=[ready_selector, selection_summary, selected_table, dispatch_button],
        )

        # Populate the blank view on first page load.
        demo.load(lambda: _empty_updates(_blank_session()), outputs=outputs)

    return demo
|
server/workflow_arena_environment.py
ADDED
|
@@ -0,0 +1,873 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""WorkflowArena event-driven workflow orchestration environment."""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
import random
|
| 13 |
+
from typing import Any
|
| 14 |
+
from uuid import uuid4
|
| 15 |
+
|
| 16 |
+
from openenv.core.env_server.interfaces import Environment
|
| 17 |
+
from openenv.core.env_server.types import State
|
| 18 |
+
|
| 19 |
+
from workflow_arena.generator import generate_episode
|
| 20 |
+
from workflow_arena.models import (
|
| 21 |
+
DifficultyPreset,
|
| 22 |
+
EpisodeConfig,
|
| 23 |
+
FailureEventType,
|
| 24 |
+
ProgressSummary,
|
| 25 |
+
RewardBreakdown,
|
| 26 |
+
SuccessMetrics,
|
| 27 |
+
TaskStatus,
|
| 28 |
+
WorkflowArenaAction,
|
| 29 |
+
WorkflowArenaObservation,
|
| 30 |
+
WorkflowEnvStateSnapshot,
|
| 31 |
+
WorkflowEpisodeSpec,
|
| 32 |
+
WorkflowFailureEvent,
|
| 33 |
+
WorkflowTaskSpec,
|
| 34 |
+
WorkflowTaskView,
|
| 35 |
+
WorkflowActionType,
|
| 36 |
+
)
|
| 37 |
+
from workflow_arena.presets import get_preset_config
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class WorkflowArenaEnvironment(Environment):
    """Resource-constrained workflow scheduler with event-driven semantics."""

    # Each UI/browser session may hold its own independent instance.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True
    # Minimum step budget for any episode.
    STEP_LIMIT_FLOOR: int = 32
    # Scaling factor for the step budget — presumably applied per task at
    # reset, with STEP_LIMIT_FLOOR as the lower bound; confirm in reset().
    STEP_LIMIT_MULTIPLIER: int = 8
    # Flat penalty for an invalid action (e.g. dispatching a non-ready task).
    INVALID_ACTION_PENALTY: float = -0.1
    # Harsher penalty when a dispatch asks for more tasks than free workers.
    OVERCAPACITY_INVALID_ACTION_PENALTY: float = -0.25
    # Penalty per idle worker slot when waiting although work could be dispatched.
    AVOIDABLE_WAIT_PENALTY_PER_SLOT: float = -0.08
    # Terminal penalty scaled by priority of tasks left unfinished.
    UNFINISHED_PRIORITY_PENALTY: float = -0.02
    # Per-tick penalty scaled by priority for tasks past their deadline.
    OVERDUE_PRIORITY_PENALTY_PER_TICK: float = -0.005
    # Cap on how many failure events are surfaced in an observation.
    MAX_RECENT_FAILURE_EVENTS: int = 6
|
| 52 |
+
|
| 53 |
+
    def __init__(self):
        """Initialize with a fresh episode id and a placeholder easy-mode config.

        The environment is not usable until reset() builds an episode spec and
        state snapshot; _require_episode() enforces that.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._cumulative_reward = 0.0
        self._max_episode_steps = self.STEP_LIMIT_FLOOR
        # Placeholder config; overwritten on reset.
        self._config = EpisodeConfig(
            preset=DifficultyPreset.EASY,
            seed=0,
            worker_count=2,
        )
        self._episode_spec: WorkflowEpisodeSpec | None = None
        self._env_state: WorkflowEnvStateSnapshot | None = None
        # Dedicated RNG for failure events — presumably reseeded per episode
        # on reset; confirm there.
        self._event_rng = random.Random(0)
|
| 65 |
+
|
| 66 |
+
def _require_episode(self) -> tuple[WorkflowEpisodeSpec, WorkflowEnvStateSnapshot]:
|
| 67 |
+
if self._episode_spec is None or self._env_state is None:
|
| 68 |
+
raise RuntimeError("Environment must be reset before use.")
|
| 69 |
+
return self._episode_spec, self._env_state
|
| 70 |
+
|
| 71 |
+
def _preset_config(self):
|
| 72 |
+
episode, _ = self._require_episode()
|
| 73 |
+
return episode.preset_config
|
| 74 |
+
|
| 75 |
+
def _task_map(self) -> dict[str, WorkflowTaskSpec]:
|
| 76 |
+
episode, _ = self._require_episode()
|
| 77 |
+
return {task.task_id: task for task in episode.tasks}
|
| 78 |
+
|
| 79 |
+
def _effective_worker_capacity(
|
| 80 |
+
self, env_state: WorkflowEnvStateSnapshot | None = None
|
| 81 |
+
) -> int:
|
| 82 |
+
if env_state is None:
|
| 83 |
+
_, env_state = self._require_episode()
|
| 84 |
+
return max(0, self._config.worker_count - env_state.degraded_workers)
|
| 85 |
+
|
| 86 |
+
def _time_remaining(
|
| 87 |
+
self, env_state: WorkflowEnvStateSnapshot | None = None
|
| 88 |
+
) -> int | None:
|
| 89 |
+
if env_state is None:
|
| 90 |
+
_, env_state = self._require_episode()
|
| 91 |
+
if env_state.time_budget is None:
|
| 92 |
+
return None
|
| 93 |
+
return max(0, env_state.time_budget - env_state.current_time)
|
| 94 |
+
|
| 95 |
+
def _terminal_score(self) -> float:
|
| 96 |
+
episode, env_state = self._require_episode()
|
| 97 |
+
if env_state.current_time <= 0:
|
| 98 |
+
return 0.0
|
| 99 |
+
lower_bound = self._lower_bound_makespan(episode)
|
| 100 |
+
score = lower_bound / max(lower_bound, env_state.current_time)
|
| 101 |
+
return round(score, 4)
|
| 102 |
+
|
| 103 |
+
def _benchmark_score(self) -> float:
|
| 104 |
+
makespan_score, deadline_score, utilization_score = self._grade_components(
|
| 105 |
+
include_terminal_makespan=True
|
| 106 |
+
)
|
| 107 |
+
return round(
|
| 108 |
+
(0.5 * makespan_score) + (0.3 * deadline_score) + (0.2 * utilization_score),
|
| 109 |
+
4,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
    def _grade_components(
        self, *, include_terminal_makespan: bool = False
    ) -> tuple[float, float, float]:
        """Return (makespan_score, deadline_score, utilization_score).

        deadline_score is the priority-weighted fraction of completed tasks
        that finished by their deadline; utilization_score is busy worker-time
        over total worker-time; makespan_score is 0.0 unless explicitly
        requested (it only makes sense at termination).
        """
        episode, env_state = self._require_episode()
        # Busy time divided by (elapsed time * total workers); 0 before any time passes.
        utilization = (
            env_state.cumulative_busy_time
            / (env_state.current_time * self._config.worker_count)
            if env_state.current_time > 0
            else 0.0
        )
        # `or 1` guards against division by zero for zero total priority.
        total_priority = sum(task.priority for task in episode.tasks) or 1
        on_time_priority = 0
        for task in episode.tasks:
            end_time = env_state.task_end_times.get(task.task_id)
            if end_time is None:
                # Task not completed yet; it contributes nothing on-time.
                continue
            if task.deadline is None or end_time <= task.deadline:
                on_time_priority += task.priority
        deadline_score = round(on_time_priority / total_priority, 4)
        utilization_score = round(utilization, 4)
        makespan_score = self._terminal_score() if include_terminal_makespan else 0.0
        return makespan_score, deadline_score, utilization_score
|
| 134 |
+
|
| 135 |
+
def _unfinished_task_penalty(self, current_time: int) -> float:
|
| 136 |
+
episode, env_state = self._require_episode()
|
| 137 |
+
penalty = 0.0
|
| 138 |
+
for task in episode.tasks:
|
| 139 |
+
if env_state.task_statuses[task.task_id] == TaskStatus.COMPLETED:
|
| 140 |
+
continue
|
| 141 |
+
penalty += self.UNFINISHED_PRIORITY_PENALTY * task.priority
|
| 142 |
+
if task.deadline is not None and current_time > task.deadline:
|
| 143 |
+
penalty += (
|
| 144 |
+
self.OVERDUE_PRIORITY_PENALTY_PER_TICK
|
| 145 |
+
* task.priority
|
| 146 |
+
* (current_time - task.deadline)
|
| 147 |
+
)
|
| 148 |
+
return round(penalty, 4)
|
| 149 |
+
|
| 150 |
+
    def _success_metrics(
        self, *, benchmark_score_override: float | None = None
    ) -> SuccessMetrics:
        """Summarize episode outcome counters into a SuccessMetrics model.

        benchmark_score is only populated when the caller supplies an override
        (typically at termination); otherwise it is left as None.
        """
        episode, env_state = self._require_episode()
        unfinished_task_count = sum(
            1
            for task in episode.tasks
            if env_state.task_statuses[task.task_id] != TaskStatus.COMPLETED
        )
        # A miss is a COMPLETED task whose recorded end time exceeds its deadline.
        deadline_miss_count = sum(
            1
            for task in episode.tasks
            if env_state.task_statuses[task.task_id] == TaskStatus.COMPLETED
            and task.deadline is not None
            and env_state.task_end_times.get(task.task_id, 0) > task.deadline
        )
        _, deadline_score, utilization_score = self._grade_components(
            include_terminal_makespan=False
        )
        all_done = unfinished_task_count == 0
        return SuccessMetrics(
            # Makespan is only meaningful once every task has completed.
            makespan=env_state.current_time if all_done else None,
            worker_utilization=utilization_score,
            deadline_miss_count=deadline_miss_count,
            unfinished_task_count=unfinished_task_count,
            weighted_priority_completion=deadline_score,
            benchmark_score=benchmark_score_override,
        )
|
| 178 |
+
|
| 179 |
+
    def _task_view(
        self,
        task: WorkflowTaskSpec,
        status: TaskStatus,
        *,
        include_planner_hints: bool = True,
    ) -> WorkflowTaskView:
        """Project one task spec plus its runtime state into an observation view.

        When include_planner_hints is False, criticality/slack are omitted and
        downstream_count is zeroed, hiding planner guidance from the agent.
        """
        _, env_state = self._require_episode()
        return WorkflowTaskView(
            task_id=task.task_id,
            status=status,
            duration=task.duration,
            priority=task.priority,
            dependencies=task.dependencies,
            deadline=task.deadline,
            criticality=task.criticality if include_planner_hints else None,
            slack=float(task.slack) if include_planner_hints else None,
            downstream_count=task.downstream_count if include_planner_hints else 0,
            start_time=env_state.task_start_times.get(task.task_id),
            # Prefer the actual end time; fall back to the scheduled finish
            # time for tasks that are still running.
            end_time=(
                env_state.task_end_times.get(task.task_id)
                or env_state.task_assigned_finish_times.get(task.task_id)
            ),
            attempt_count=env_state.task_attempt_counts.get(task.task_id, 0),
        )
|
| 204 |
+
|
| 205 |
+
def _task_views_for_status(self, status: TaskStatus) -> list[WorkflowTaskView]:
|
| 206 |
+
episode, env_state = self._require_episode()
|
| 207 |
+
return [
|
| 208 |
+
self._task_view(task, status, include_planner_hints=True)
|
| 209 |
+
for task in episode.tasks
|
| 210 |
+
if env_state.task_statuses[task.task_id] == status
|
| 211 |
+
]
|
| 212 |
+
|
| 213 |
+
    def debug_task_views_for_status(self, status: TaskStatus) -> list[WorkflowTaskView]:
        """Public debugging/testing hook around the private view builder."""
        return self._task_views_for_status(status)
|
| 215 |
+
|
| 216 |
+
def _set_recent_failure_events(
|
| 217 |
+
self,
|
| 218 |
+
env_state: WorkflowEnvStateSnapshot,
|
| 219 |
+
events: list[WorkflowFailureEvent],
|
| 220 |
+
) -> None:
|
| 221 |
+
env_state.recent_failure_events = events[-self.MAX_RECENT_FAILURE_EVENTS :]
|
| 222 |
+
|
| 223 |
+
def _maybe_end_worker_outage(
|
| 224 |
+
self,
|
| 225 |
+
env_state: WorkflowEnvStateSnapshot,
|
| 226 |
+
events: list[WorkflowFailureEvent],
|
| 227 |
+
) -> None:
|
| 228 |
+
if (
|
| 229 |
+
env_state.active_worker_outage_until is not None
|
| 230 |
+
and env_state.current_time >= env_state.active_worker_outage_until
|
| 231 |
+
):
|
| 232 |
+
events.append(
|
| 233 |
+
WorkflowFailureEvent(
|
| 234 |
+
event_type=FailureEventType.WORKER_OUTAGE_END,
|
| 235 |
+
time=env_state.current_time,
|
| 236 |
+
worker_delta=1,
|
| 237 |
+
detail="Worker capacity restored.",
|
| 238 |
+
)
|
| 239 |
+
)
|
| 240 |
+
env_state.active_worker_outage_until = None
|
| 241 |
+
env_state.degraded_workers = 0
|
| 242 |
+
|
| 243 |
+
    def _maybe_start_worker_outage(
        self,
        env_state: WorkflowEnvStateSnapshot,
        events: list[WorkflowFailureEvent],
    ) -> None:
        """Possibly start a one-worker outage (HARD preset only).

        The order of `_event_rng` draws (rate check, then duration) is
        load-bearing for seed reproducibility — do not reorder the guards
        relative to the RNG calls.
        """
        preset_config = self._preset_config()
        if self._config.preset != DifficultyPreset.HARD:
            return
        if env_state.active_worker_outage_until is not None:
            # Only one outage may be active at a time.
            return
        if preset_config.worker_outage_rate <= 0.0:
            return
        if self._event_rng.random() >= preset_config.worker_outage_rate:
            return

        duration = self._event_rng.randint(
            preset_config.worker_outage_duration_min,
            preset_config.worker_outage_duration_max,
        )
        if duration <= 0:
            return

        # At most one worker degrades, clamped by the configured worker count.
        env_state.degraded_workers = min(1, self._config.worker_count)
        env_state.active_worker_outage_until = env_state.current_time + duration
        events.append(
            WorkflowFailureEvent(
                event_type=FailureEventType.WORKER_OUTAGE_START,
                time=env_state.current_time,
                worker_delta=-env_state.degraded_workers,
                duration=duration,
                detail=f"Worker outage active until t={env_state.active_worker_outage_until}.",
            )
        )
|
| 276 |
+
|
| 277 |
+
def _should_retry_fail(self, task_id: str) -> bool:
|
| 278 |
+
preset_config = self._preset_config()
|
| 279 |
+
_, env_state = self._require_episode()
|
| 280 |
+
if self._config.preset != DifficultyPreset.HARD:
|
| 281 |
+
return False
|
| 282 |
+
if preset_config.task_retry_failure_rate <= 0.0:
|
| 283 |
+
return False
|
| 284 |
+
if env_state.task_attempt_counts.get(task_id, 0) >= preset_config.max_task_retries:
|
| 285 |
+
return False
|
| 286 |
+
return self._event_rng.random() < preset_config.task_retry_failure_rate
|
| 287 |
+
|
| 288 |
+
    def _dispatch_potential(
        self,
        env_state: WorkflowEnvStateSnapshot,
        task_map: dict[str, WorkflowTaskSpec],
    ) -> tuple[float, float]:
        """Potential-style shaping terms for the current running set.

        Returns (utilization_component, criticality_component); dispatch
        rewards are computed as the difference of these potentials before and
        after the dispatch.
        """
        if not env_state.running_task_ids:
            return 0.0, 0.0

        episode, _ = self._require_episode()
        max_slack = max((task.slack for task in episode.tasks), default=0)
        # Scales with the fraction of workers kept busy.
        utilization_component = 0.06 * (
            len(env_state.running_task_ids) / max(1, self._config.worker_count)
        )
        criticality_component = 0.0
        for task_id in env_state.running_task_ids:
            task = task_map[task_id]
            # Urgency is 1 for zero-slack tasks, decreasing toward 0 at max slack.
            slack_urgency = 1.0 if max_slack <= 0 else 1.0 - (task.slack / max_slack)
            criticality_component += (0.6 * task.criticality) + (0.4 * slack_urgency)
        criticality_component = 0.04 * (
            criticality_component / max(1, self._config.worker_count)
        )
        return round(utilization_component, 4), round(criticality_component, 4)
|
| 310 |
+
|
| 311 |
+
    def _base_observation(
        self,
        *,
        reward: float,
        breakdown: RewardBreakdown,
        note: str,
        done: bool,
        benchmark_score_override: float | None = None,
    ) -> WorkflowArenaObservation:
        """Assemble the full observation for the current simulation state.

        Collects per-status task views, worker capacity, time budget, success
        metrics, and a metadata mirror of the headline fields. The benchmark
        score is only attached when the caller passes an override.
        """
        episode, env_state = self._require_episode()
        ready_tasks = self._task_views_for_status(TaskStatus.READY)
        running_tasks = self._task_views_for_status(TaskStatus.RUNNING)
        completed_tasks = self._task_views_for_status(TaskStatus.COMPLETED)
        blocked_tasks = self._task_views_for_status(TaskStatus.BLOCKED)
        effective_workers = self._effective_worker_capacity(env_state)
        return WorkflowArenaObservation(
            done=done,
            reward=reward,
            config=self._config,
            current_time=env_state.current_time,
            total_workers=self._config.worker_count,
            effective_workers=effective_workers,
            degraded_workers=env_state.degraded_workers,
            # Free slots = usable workers minus those occupied by running tasks.
            free_workers=max(0, effective_workers - len(running_tasks)),
            time_budget=env_state.time_budget,
            time_remaining=self._time_remaining(env_state),
            progress=ProgressSummary(
                total=len(episode.tasks),
                blocked=len(blocked_tasks),
                ready=len(ready_tasks),
                running=len(running_tasks),
                completed=len(completed_tasks),
            ),
            ready_tasks=ready_tasks,
            running_tasks=running_tasks,
            completed_tasks=completed_tasks,
            blocked_tasks=blocked_tasks,
            last_reward_breakdown=breakdown,
            cumulative_reward=self._cumulative_reward,
            success_metrics=self._success_metrics(
                benchmark_score_override=benchmark_score_override
            ),
            note=note,
            benchmark_score=benchmark_score_override,
            recent_failure_events=env_state.recent_failure_events,
            # Metadata duplicates key fields as plain JSON for generic clients.
            metadata={
                "phase": "simulation_active",
                "note": note,
                "effective_workers": effective_workers,
                "degraded_workers": env_state.degraded_workers,
                "time_budget": env_state.time_budget,
                "time_remaining": self._time_remaining(env_state),
                "recent_failure_events": [
                    event.model_dump(mode="json") for event in env_state.recent_failure_events
                ],
                "episode_loop": [
                    "reset generates a seeded workflow DAG episode",
                    "dispatch(task_ids=[...]) starts ready tasks if workers are free",
                    "wait() advances simulated time to the next completion event",
                    "medium and hard episodes may end at a fixed time budget",
                    "hard mode may trigger outages and retry failures",
                ],
            },
        )
|
| 375 |
+
|
| 376 |
+
def _lower_bound_makespan(self, episode: WorkflowEpisodeSpec) -> int:
|
| 377 |
+
total_work = sum(task.duration for task in episode.tasks)
|
| 378 |
+
work_bound = (total_work + self._config.worker_count - 1) // self._config.worker_count
|
| 379 |
+
path_bound = max(task.critical_path_length for task in episode.tasks)
|
| 380 |
+
return max(1, work_bound, path_bound)
|
| 381 |
+
|
| 382 |
+
def _termination_breakdown(
|
| 383 |
+
self,
|
| 384 |
+
*,
|
| 385 |
+
invalid_penalty: float = 0.0,
|
| 386 |
+
idle_penalty: float = 0.0,
|
| 387 |
+
terminal_makespan_score: float = 0.0,
|
| 388 |
+
unfinished_task_penalty: float = 0.0,
|
| 389 |
+
) -> RewardBreakdown:
|
| 390 |
+
return RewardBreakdown(
|
| 391 |
+
invalid_action_penalty=round(invalid_penalty, 4),
|
| 392 |
+
idle_penalty=round(idle_penalty, 4),
|
| 393 |
+
terminal_makespan_score=round(terminal_makespan_score, 4),
|
| 394 |
+
unfinished_task_penalty=round(unfinished_task_penalty, 4),
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
    def _terminate_episode(
        self,
        *,
        note: str,
        breakdown: RewardBreakdown,
        reward: float,
        reason: str,
        benchmark_score: float | None = None,
    ) -> WorkflowArenaObservation:
        """Build the final (done=True) observation and book the terminal reward.

        The benchmark score is computed here when the caller has not already
        done so, and is mirrored into both typed fields and metadata.
        """
        if benchmark_score is None:
            benchmark_score = self._benchmark_score()
        self._cumulative_reward += reward
        observation = self._base_observation(
            reward=reward,
            breakdown=breakdown,
            note=note,
            done=True,
            benchmark_score_override=benchmark_score,
        )
        observation.termination_reason = reason
        observation.benchmark_score = benchmark_score
        observation.metadata["termination_reason"] = reason
        observation.metadata["benchmark_score"] = benchmark_score
        return observation
|
| 421 |
+
|
| 422 |
+
def _step_limit_reached(self) -> bool:
|
| 423 |
+
return self._state.step_count >= self._max_episode_steps
|
| 424 |
+
|
| 425 |
+
    def _maybe_terminate_for_limits(self) -> WorkflowArenaObservation | None:
        """Terminate with a -1 base penalty when the safety step limit is hit.

        Returns None while the limit has not been reached; otherwise a final
        observation with reason "step_limit".
        """
        if not self._step_limit_reached():
            return None
        _, env_state = self._require_episode()
        unfinished_penalty = self._unfinished_task_penalty(env_state.current_time)
        terminal_score = self._terminal_score()
        breakdown = self._termination_breakdown(
            terminal_makespan_score=terminal_score,
            unfinished_task_penalty=unfinished_penalty,
        )
        # Base -1.0 charge for exhausting the step budget, softened by any
        # partial makespan credit.
        reward = round(-1.0 + unfinished_penalty + terminal_score, 4)
        return self._terminate_episode(
            note="Episode terminated after hitting the safety step limit.",
            breakdown=breakdown,
            reward=reward,
            reason="step_limit",
        )
|
| 442 |
+
|
| 443 |
+
def _apply_invalid(
|
| 444 |
+
self,
|
| 445 |
+
message: str,
|
| 446 |
+
*,
|
| 447 |
+
penalty: float | None = None,
|
| 448 |
+
) -> WorkflowArenaObservation:
|
| 449 |
+
_, env_state = self._require_episode()
|
| 450 |
+
applied_penalty = (
|
| 451 |
+
self.INVALID_ACTION_PENALTY if penalty is None else float(penalty)
|
| 452 |
+
)
|
| 453 |
+
breakdown = RewardBreakdown(invalid_action_penalty=round(applied_penalty, 4))
|
| 454 |
+
self._cumulative_reward += breakdown.invalid_action_penalty
|
| 455 |
+
self._set_recent_failure_events(env_state, [])
|
| 456 |
+
observation = self._base_observation(
|
| 457 |
+
reward=breakdown.invalid_action_penalty,
|
| 458 |
+
breakdown=breakdown,
|
| 459 |
+
note="Invalid action.",
|
| 460 |
+
done=False,
|
| 461 |
+
)
|
| 462 |
+
observation.validation_error = message
|
| 463 |
+
observation.metadata["validation_error"] = message
|
| 464 |
+
return observation
|
| 465 |
+
|
| 466 |
+
def _transition_unlocks(self, completed_task_ids: list[str]) -> list[str]:
|
| 467 |
+
episode, env_state = self._require_episode()
|
| 468 |
+
task_map = {task.task_id: task for task in episode.tasks}
|
| 469 |
+
unlocked: list[str] = []
|
| 470 |
+
for task_id in completed_task_ids:
|
| 471 |
+
for dependent_id in task_map[task_id].dependents:
|
| 472 |
+
env_state.task_remaining_dependencies[dependent_id] -= 1
|
| 473 |
+
if env_state.task_remaining_dependencies[dependent_id] == 0:
|
| 474 |
+
env_state.task_statuses[dependent_id] = TaskStatus.READY
|
| 475 |
+
if dependent_id not in env_state.ready_task_ids:
|
| 476 |
+
env_state.ready_task_ids.append(dependent_id)
|
| 477 |
+
if dependent_id in env_state.blocked_task_ids:
|
| 478 |
+
env_state.blocked_task_ids.remove(dependent_id)
|
| 479 |
+
unlocked.append(dependent_id)
|
| 480 |
+
env_state.ready_task_ids.sort()
|
| 481 |
+
env_state.blocked_task_ids.sort()
|
| 482 |
+
return unlocked
|
| 483 |
+
|
| 484 |
+
    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> WorkflowArenaObservation:
        """Generate a seeded workflow DAG episode.

        Recognized kwargs: ``preset`` (DifficultyPreset or its string value)
        and ``worker_count`` (int, defaults to the preset's value). Any other
        kwargs are silently discarded.
        """

        preset_raw = kwargs.pop("preset", DifficultyPreset.EASY)
        worker_count_raw = kwargs.pop("worker_count", None)
        del kwargs
        # Accept either the enum itself or its string value.
        preset = (
            preset_raw
            if isinstance(preset_raw, DifficultyPreset)
            else DifficultyPreset(str(preset_raw))
        )
        preset_config = get_preset_config(preset)
        chosen_seed = 0 if seed is None else seed
        chosen_worker_count = (
            preset_config.worker_count
            if worker_count_raw is None
            else int(worker_count_raw)
        )
        chosen_episode_id = str(uuid4()) if episode_id is None else episode_id
        self._state = State(episode_id=chosen_episode_id, step_count=0)
        self._cumulative_reward = 0.0
        self._config = EpisodeConfig(
            preset=preset,
            seed=chosen_seed,
            worker_count=chosen_worker_count,
        )
        # Derive a distinct event-RNG seed from (seed, workers, preset) so
        # stochastic outages/retries are reproducible per configuration.
        self._event_rng = random.Random(
            (chosen_seed + 1) * 1009
            + (chosen_worker_count * 131)
            + (list(DifficultyPreset).index(preset) + 1)
        )
        self._episode_spec, self._env_state = generate_episode(self._config)
        # Safety step cap scales with episode size.
        self._max_episode_steps = max(
            self.STEP_LIMIT_FLOOR,
            len(self._episode_spec.tasks) * self.STEP_LIMIT_MULTIPLIER,
        )
        self._env_state.episode_id = chosen_episode_id
        lower_bound = self._lower_bound_makespan(self._episode_spec)
        # Budgeted presets cap episode time at a multiple of the lower bound.
        if preset_config.time_budget_multiplier is not None:
            self._env_state.time_budget = int(
                math.ceil(lower_bound * preset_config.time_budget_multiplier)
            )
        self._set_recent_failure_events(self._env_state, [])

        note = "Workflow episode generated. Dispatch ready tasks or wait for completions."
        if self._env_state.time_budget is not None:
            note = (
                f"Workflow episode generated. Finish as much as possible before "
                f"t={self._env_state.time_budget}."
            )
        if preset == DifficultyPreset.HARD:
            note += " Hard mode may trigger worker outages and retry failures."
        return self._base_observation(
            reward=0.0,
            breakdown=RewardBreakdown(),
            note=note,
            done=False,
        )
|
| 547 |
+
|
| 548 |
+
def _wait_note(
|
| 549 |
+
self,
|
| 550 |
+
*,
|
| 551 |
+
completed_now: list[str],
|
| 552 |
+
failed_now: list[str],
|
| 553 |
+
unlocked: list[str],
|
| 554 |
+
recent_events: list[WorkflowFailureEvent],
|
| 555 |
+
time_budget_hit: bool = False,
|
| 556 |
+
) -> str:
|
| 557 |
+
chunks: list[str] = []
|
| 558 |
+
if time_budget_hit:
|
| 559 |
+
chunks.append("Time budget exhausted before the next completion event.")
|
| 560 |
+
elif completed_now:
|
| 561 |
+
chunks.append(f"Completed: {', '.join(completed_now)}.")
|
| 562 |
+
else:
|
| 563 |
+
chunks.append("Advanced to next completion event.")
|
| 564 |
+
if failed_now:
|
| 565 |
+
chunks.append(f"Retry required: {', '.join(failed_now)}.")
|
| 566 |
+
if unlocked:
|
| 567 |
+
chunks.append(f"Unlocked: {', '.join(unlocked)}.")
|
| 568 |
+
for event in recent_events:
|
| 569 |
+
if event.event_type == FailureEventType.WORKER_OUTAGE_START:
|
| 570 |
+
chunks.append(event.detail)
|
| 571 |
+
elif event.event_type == FailureEventType.WORKER_OUTAGE_END:
|
| 572 |
+
chunks.append("Worker capacity restored.")
|
| 573 |
+
return " ".join(chunks)
|
| 574 |
+
|
| 575 |
+
    def step(
        self,
        action: WorkflowArenaAction,
        timeout_s: float | None = None,
        **kwargs: Any,
    ) -> WorkflowArenaObservation:
        """Apply a dispatch or wait action using event-driven semantics.

        DISPATCH validates and starts ready tasks (reward = potential delta).
        WAIT advances time to the next completion event (or the time budget),
        resolves completions/retries, unlocks dependents, and shapes rewards.
        Invalid actions are penalized without advancing time.
        """

        del timeout_s, kwargs
        episode, env_state = self._require_episode()
        task_map = {task.task_id: task for task in episode.tasks}
        self._state.step_count += 1
        # Failure events are per-step; start each step with a clean list.
        self._set_recent_failure_events(env_state, [])

        limit_termination = self._maybe_terminate_for_limits()
        if limit_termination is not None:
            return limit_termination

        if action.action_type == WorkflowActionType.WAIT and action.task_ids:
            return self._apply_invalid("wait() must not include task_ids.")

        if action.action_type == WorkflowActionType.DISPATCH:
            # --- DISPATCH validation chain ---
            if not action.task_ids:
                return self._apply_invalid(
                    "dispatch(task_ids=[...]) requires at least one task id."
                )
            if len(set(action.task_ids)) != len(action.task_ids):
                return self._apply_invalid(
                    "dispatch(task_ids=[...]) must not contain duplicate task ids."
                )

            free_workers = self._effective_worker_capacity(env_state) - len(
                env_state.running_task_ids
            )
            if len(action.task_ids) > max(0, free_workers):
                # Over-capacity dispatch draws the harsher penalty.
                return self._apply_invalid(
                    "dispatch(task_ids=[...]) cannot exceed available worker capacity.",
                    penalty=self.OVERCAPACITY_INVALID_ACTION_PENALTY,
                )

            unknown_tasks = [task_id for task_id in action.task_ids if task_id not in task_map]
            if unknown_tasks:
                return self._apply_invalid(f"Unknown task ids: {unknown_tasks}.")

            not_ready = [
                task_id
                for task_id in action.task_ids
                if env_state.task_statuses[task_id] != TaskStatus.READY
            ]
            if not_ready:
                return self._apply_invalid(
                    f"Only ready tasks can be dispatched: {not_ready}."
                )

            # Reward = change in dispatch potential before vs after starting tasks.
            prev_utilization_potential, prev_criticality_potential = self._dispatch_potential(
                env_state, task_map
            )
            for task_id in action.task_ids:
                task = task_map[task_id]
                env_state.task_statuses[task_id] = TaskStatus.RUNNING
                env_state.task_start_times[task_id] = env_state.current_time
                env_state.task_assigned_finish_times[task_id] = (
                    env_state.current_time + task.duration
                )
                env_state.running_task_ids.append(task_id)
                env_state.ready_task_ids.remove(task_id)
            env_state.running_task_ids.sort()
            next_utilization_potential, next_criticality_potential = self._dispatch_potential(
                env_state, task_map
            )
            breakdown = RewardBreakdown(
                utilization_reward=round(
                    next_utilization_potential - prev_utilization_potential, 4
                ),
                criticality_reward=round(
                    next_criticality_potential - prev_criticality_potential, 4
                ),
            )
            reward = round(
                breakdown.utilization_reward + breakdown.criticality_reward,
                4,
            )
            self._cumulative_reward += reward
            observation = self._base_observation(
                reward=reward,
                breakdown=breakdown,
                note="Tasks dispatched. Use wait() to advance to the next completion event.",
                done=False,
            )
            observation.received_action = action.model_dump(mode="json")
            observation.metadata["received_action"] = action.model_dump(mode="json")
            return observation

        # --- WAIT path ---
        if not env_state.running_task_ids:
            return self._apply_invalid("wait() requires at least one running task.")

        recent_events: list[WorkflowFailureEvent] = []
        # Penalize waiting while free workers could have taken ready tasks.
        avoidable_wait_penalty = 0.0
        if env_state.ready_task_ids:
            free_workers = self._effective_worker_capacity(env_state) - len(
                env_state.running_task_ids
            )
            if free_workers > 0:
                avoidable_wait_penalty = self.AVOIDABLE_WAIT_PENALTY_PER_SLOT * min(
                    free_workers,
                    len(env_state.ready_task_ids),
                )

        self._maybe_start_worker_outage(env_state, recent_events)

        # Advance to the earliest scheduled finish, clipped to the time budget.
        next_completion_time = min(
            env_state.task_assigned_finish_times[task_id]
            for task_id in env_state.running_task_ids
        )
        target_time = next_completion_time
        budget_hit_before_completion = False
        if env_state.time_budget is not None and env_state.time_budget < next_completion_time:
            target_time = env_state.time_budget
            budget_hit_before_completion = True

        elapsed = target_time - env_state.current_time
        # Every running task is busy for the whole elapsed interval.
        env_state.cumulative_busy_time += elapsed * len(env_state.running_task_ids)
        env_state.current_time = target_time
        self._maybe_end_worker_outage(env_state, recent_events)

        if budget_hit_before_completion:
            # Budget ran out mid-interval: terminate without resolving completions.
            unfinished_penalty = self._unfinished_task_penalty(env_state.current_time)
            terminal_score = self._terminal_score()
            breakdown = RewardBreakdown(
                idle_penalty=round(avoidable_wait_penalty, 4),
                terminal_makespan_score=round(terminal_score, 4),
                unfinished_task_penalty=round(unfinished_penalty, 4),
            )
            reward = round(
                breakdown.idle_penalty
                + breakdown.terminal_makespan_score
                + breakdown.unfinished_task_penalty,
                4,
            )
            self._set_recent_failure_events(env_state, recent_events)
            note = self._wait_note(
                completed_now=[],
                failed_now=[],
                unlocked=[],
                recent_events=recent_events,
                time_budget_hit=True,
            )
            observation = self._terminate_episode(
                note=note,
                breakdown=breakdown,
                reward=reward,
                reason="time_budget",
            )
            observation.received_action = action.model_dump(mode="json")
            observation.metadata["received_action"] = action.model_dump(mode="json")
            return observation

        # Resolve every task finishing at this event time (sorted for determinism).
        completed_candidates = sorted(
            [
                task_id
                for task_id in env_state.running_task_ids
                if env_state.task_assigned_finish_times[task_id] == next_completion_time
            ]
        )
        completed_now: list[str] = []
        failed_now: list[str] = []
        for task_id in completed_candidates:
            env_state.running_task_ids.remove(task_id)
            del env_state.task_assigned_finish_times[task_id]
            if self._should_retry_fail(task_id):
                # Hard-mode retry failure: send the task back to READY.
                env_state.task_attempt_counts[task_id] += 1
                env_state.task_statuses[task_id] = TaskStatus.READY
                env_state.task_start_times.pop(task_id, None)
                env_state.task_end_times.pop(task_id, None)
                if task_id not in env_state.ready_task_ids:
                    env_state.ready_task_ids.append(task_id)
                failed_now.append(task_id)
                recent_events.append(
                    WorkflowFailureEvent(
                        event_type=FailureEventType.TASK_RETRY_FAILURE,
                        time=next_completion_time,
                        task_id=task_id,
                        detail=f"{task_id} failed and returned to ready.",
                    )
                )
            else:
                env_state.task_statuses[task_id] = TaskStatus.COMPLETED
                env_state.task_end_times[task_id] = next_completion_time
                env_state.completed_task_ids.append(task_id)
                completed_now.append(task_id)

        env_state.completed_task_ids.sort()
        env_state.ready_task_ids.sort()
        unlocked = self._transition_unlocks(completed_now)

        # --- Reward shaping for the wait transition ---
        completion_reward = sum(
            0.04 + 0.01 * task_map[task_id].priority for task_id in completed_now
        )
        deadline_reward = 0.0
        criticality_reward = 0.0
        for task_id in completed_now:
            task = task_map[task_id]
            if task.deadline is not None:
                # Bonus for on-time finish; linear penalty per tick of lateness.
                lateness = next_completion_time - task.deadline
                deadline_reward += 0.05 if lateness <= 0 else -0.02 * lateness
            criticality_reward += 0.03 * task.criticality

        utilization_reward = 0.06 * (
            elapsed
            * (len(completed_candidates) + len(env_state.running_task_ids))
            / max(1, self._config.worker_count)
        )
        idle_penalty = 0.0
        if not env_state.running_task_ids and env_state.ready_task_ids:
            # Nothing running but work available: scheduler is stalling.
            idle_penalty = -0.03 * len(env_state.ready_task_ids)

        done = len(env_state.completed_task_ids) == len(episode.tasks)
        breakdown = RewardBreakdown(
            completion_reward=round(completion_reward, 4),
            utilization_reward=round(utilization_reward, 4),
            deadline_reward=round(deadline_reward, 4),
            criticality_reward=round(criticality_reward, 4),
            idle_penalty=round(idle_penalty + avoidable_wait_penalty, 4),
            terminal_makespan_score=round(self._terminal_score() if done else 0.0, 4),
        )
        reward = round(
            breakdown.completion_reward
            + breakdown.utilization_reward
            + breakdown.deadline_reward
            + breakdown.criticality_reward
            + breakdown.idle_penalty
            + breakdown.terminal_makespan_score,
            4,
        )

        # Budget exactly exhausted at this event with work remaining: terminate.
        budget_exhausted_now = (
            not done
            and env_state.time_budget is not None
            and env_state.current_time >= env_state.time_budget
        )
        if budget_exhausted_now:
            unfinished_penalty = self._unfinished_task_penalty(env_state.current_time)
            breakdown.unfinished_task_penalty = round(unfinished_penalty, 4)
            breakdown.terminal_makespan_score = round(self._terminal_score(), 4)
            reward = round(
                reward
                + breakdown.unfinished_task_penalty
                + breakdown.terminal_makespan_score,
                4,
            )
            self._set_recent_failure_events(env_state, recent_events)
            note = self._wait_note(
                completed_now=completed_now,
                failed_now=failed_now,
                unlocked=unlocked,
                recent_events=recent_events,
                time_budget_hit=True,
            )
            observation = self._terminate_episode(
                note=note,
                breakdown=breakdown,
                reward=reward,
                reason="time_budget",
            )
            observation.received_action = action.model_dump(mode="json")
            observation.metadata["received_action"] = action.model_dump(mode="json")
            observation.metadata["completed_now"] = completed_now
            observation.metadata["unlocked_now"] = unlocked
            observation.metadata["failed_now"] = failed_now
            return observation

        self._cumulative_reward += reward
        self._set_recent_failure_events(env_state, recent_events)
        observation = self._base_observation(
            reward=reward,
            breakdown=breakdown,
            note=self._wait_note(
                completed_now=completed_now,
                failed_now=failed_now,
                unlocked=unlocked,
                recent_events=recent_events,
            ),
            done=done,
            benchmark_score_override=self._benchmark_score() if done else None,
        )
        observation.received_action = action.model_dump(mode="json")
        observation.metadata["received_action"] = action.model_dump(mode="json")
        observation.metadata["completed_now"] = completed_now
        observation.metadata["unlocked_now"] = unlocked
        observation.metadata["failed_now"] = failed_now
        if done:
            observation.benchmark_score = observation.success_metrics.benchmark_score
        return observation
|
| 868 |
+
|
| 869 |
+
    @property
    def state(self) -> State:
        """Expose generic OpenEnv state metadata (episode id, step count)."""

        return self._state
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|