Spaces:
Running
Running
Final compliance check for MetaXScalar: Remove underscores from task IDs, update inference logging, and ensure score clipping
"""
Multi-task configuration for the OpenEnv bus routing environment.

Three difficulty tiers — Easy, Medium, Hard — share the same
``BusRoutingEnv`` class but differ in the number of stops, passenger
demand, fuel constraints, and penalty intensity.

Each tier is expanded into 10 subtasks (30 tasks total).
"""
| from __future__ import annotations | |
| import copy | |
| from dataclasses import dataclass | |
| from typing import Any, Dict | |
| from environment import BusRoutingEnv | |
# Public names exported for OpenEnv detection. This must stay a mutable list:
# the per-task names are appended to it while the subtasks are generated.
__all__ = [
    "TaskConfig", "TASKS",
    "TASK_EASY", "TASK_MEDIUM", "TASK_HARD",
    "get_task",
]
@dataclass
class TaskConfig:
    """All parameters needed to instantiate a BusRoutingEnv for a task.

    ``name``, ``description`` and ``difficulty`` are task metadata and are
    not passed to the environment. Every other field is forwarded as the
    same-named keyword argument to ``BusRoutingEnv`` by :meth:`build_env`.

    The ``@dataclass`` decorator generates the keyword-argument constructor
    that the task templates in this module rely on.
    """

    # --- Task metadata (not forwarded to the environment) ---
    name: str = ""
    description: str = ""
    difficulty: str = "medium"  # easy | medium | hard

    # --- Environment parameters (forwarded by build_env) ---
    num_stops: int = 10
    num_buses: int = 1
    max_steps: int = 150  # may be overridden at build time, see build_env
    seed: int = 42
    bus_capacity: int = 30
    fuel_start: float = 100.0
    passenger_arrival_rate: float = 1.2
    large_queue_threshold: int = 10
    wait_time_threshold: int = 3
    fuel_cost_move: float = 1.0
    fuel_cost_wait: float = 0.2
    background_bus_pickup_fraction: float = 0.6
    new_stop_bonus: float = 1.0
    idle_camping_penalty: float = 0.6
    camping_grace_steps: int = 1
    nearby_queue_ignore_penalty: float = 1.5
    recent_window: int = 10
    recent_unvisited_bonus: float = 1.0
    repeat_stop_penalty: float = 0.5
    high_queue_reward_threshold: int = 6
    high_queue_visit_bonus: float = 2.0
    reward_clip: float = 10.0
    demand_profile: str = "synthetic"

    def build_env(self) -> BusRoutingEnv:
        """Create a ``BusRoutingEnv`` configured from this task.

        If the ``EVAL_MAX_STEPS`` environment variable is set, its integer
        value replaces ``self.max_steps`` for the environment being built;
        the config object itself is not modified.

        Raises:
            ValueError: if ``EVAL_MAX_STEPS`` is set to a value that
                ``int()`` cannot parse.
        """
        import os

        # os.getenv returns the default unchanged when the variable is unset,
        # so int() receives either the override string or self.max_steps.
        m_steps = int(os.getenv("EVAL_MAX_STEPS", self.max_steps))
        return BusRoutingEnv(
            num_stops=self.num_stops,
            num_buses=self.num_buses,
            max_steps=m_steps,
            seed=self.seed,
            bus_capacity=self.bus_capacity,
            fuel_start=self.fuel_start,
            passenger_arrival_rate=self.passenger_arrival_rate,
            large_queue_threshold=self.large_queue_threshold,
            wait_time_threshold=self.wait_time_threshold,
            fuel_cost_move=self.fuel_cost_move,
            fuel_cost_wait=self.fuel_cost_wait,
            background_bus_pickup_fraction=self.background_bus_pickup_fraction,
            new_stop_bonus=self.new_stop_bonus,
            idle_camping_penalty=self.idle_camping_penalty,
            camping_grace_steps=self.camping_grace_steps,
            nearby_queue_ignore_penalty=self.nearby_queue_ignore_penalty,
            recent_window=self.recent_window,
            recent_unvisited_bonus=self.recent_unvisited_bonus,
            repeat_stop_penalty=self.repeat_stop_penalty,
            high_queue_reward_threshold=self.high_queue_reward_threshold,
            high_queue_visit_bonus=self.high_queue_visit_bonus,
            reward_clip=self.reward_clip,
            demand_profile=self.demand_profile,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict summary of the task.

        Only the metadata and a subset of the environment parameters are
        included; reward-shaping fields, ``seed`` and ``demand_profile`` are
        deliberately left out, matching the keys listed below.
        """
        return {
            "name": self.name,
            "difficulty": self.difficulty,
            "description": self.description,
            "num_stops": self.num_stops,
            "num_buses": self.num_buses,
            "max_steps": self.max_steps,
            "fuel_start": self.fuel_start,
            "passenger_arrival_rate": self.passenger_arrival_rate,
            "fuel_cost_move": self.fuel_cost_move,
            "fuel_cost_wait": self.fuel_cost_wait,
            "large_queue_threshold": self.large_queue_threshold,
            "bus_capacity": self.bus_capacity,
        }
# Base configuration for the "easy" tier. Fields not listed here keep the
# TaskConfig defaults; name and description are filled in per variant when
# the subtasks are generated.
_TASK_EASY_TEMPLATE = TaskConfig(
    difficulty="easy",
    demand_profile="off_peak",
    seed=42,
    # Scale
    num_stops=5,
    num_buses=1,
    max_steps=100,
    bus_capacity=30,
    # Fuel
    fuel_start=100.0,
    fuel_cost_move=0.5,
    fuel_cost_wait=0.1,
    # Demand and queue thresholds
    passenger_arrival_rate=0.6,
    large_queue_threshold=12,
    wait_time_threshold=5,
    high_queue_reward_threshold=8,
    # Bonuses, penalties and clipping
    new_stop_bonus=0.5,
    idle_camping_penalty=0.3,
    nearby_queue_ignore_penalty=0.5,
    repeat_stop_penalty=0.2,
    reward_clip=10.0,
)
# Base configuration for the "medium" tier. Fields not listed here keep the
# TaskConfig defaults; name and description are filled in per variant when
# the subtasks are generated.
_TASK_MEDIUM_TEMPLATE = TaskConfig(
    difficulty="medium",
    demand_profile="weekday",
    seed=42,
    # Scale
    num_stops=10,
    num_buses=1,
    max_steps=150,
    bus_capacity=30,
    # Fuel
    fuel_start=100.0,
    fuel_cost_move=1.0,
    fuel_cost_wait=0.2,
    # Demand and queue thresholds
    passenger_arrival_rate=1.2,
    large_queue_threshold=10,
    wait_time_threshold=3,
    high_queue_reward_threshold=6,
    # Bonuses, penalties and clipping
    new_stop_bonus=1.0,
    idle_camping_penalty=0.6,
    nearby_queue_ignore_penalty=1.5,
    repeat_stop_penalty=0.5,
    reward_clip=10.0,
)
# Base configuration for the "hard" tier. Fields not listed here keep the
# TaskConfig defaults; name and description are filled in per variant when
# the subtasks are generated.
_TASK_HARD_TEMPLATE = TaskConfig(
    difficulty="hard",
    demand_profile="peak_hour",
    seed=42,
    # Scale
    num_stops=12,
    num_buses=2,
    max_steps=200,
    bus_capacity=25,
    # Fuel
    fuel_start=80.0,
    fuel_cost_move=1.5,
    fuel_cost_wait=0.4,
    # Demand and queue thresholds
    passenger_arrival_rate=2.0,
    large_queue_threshold=8,
    wait_time_threshold=2,
    high_queue_reward_threshold=5,
    # Bonuses, penalties and clipping
    new_stop_bonus=1.5,
    high_queue_visit_bonus=3.0,
    idle_camping_penalty=1.0,
    camping_grace_steps=0,
    nearby_queue_ignore_penalty=2.5,
    repeat_stop_penalty=0.8,
    reward_clip=15.0,
)
# Registry of every generated task, keyed by its ID ("task1" .. "task30").
TASKS: Dict[str, TaskConfig] = {}

# Generate 10 subtasks for each difficulty level (30 tasks total, no underscores).
# For each variant index the easy, medium and hard entries are registered in
# that order, so insertion order is task1, task11, task21, task2, task12, ...
for i in range(1, 11):
    # (template, ID offset, description label, arrival-rate step per variant)
    for template, id_offset, label, rate_step in (
        (_TASK_EASY_TEMPLATE, 0, "Easy", 0.05),      # task1  .. task10
        (_TASK_MEDIUM_TEMPLATE, 10, "Medium", 0.1),  # task11 .. task20
        (_TASK_HARD_TEMPLATE, 20, "Hard", 0.15),     # task21 .. task30
    ):
        # Deep-copy so that a variant never shares state with its template.
        variant = copy.deepcopy(template)
        variant.name = f"task{id_offset + i}"
        variant.description = f"{label} variant {i}"
        variant.seed = 42 + i
        variant.passenger_arrival_rate += (i - 1) * rate_step

        # Expose the variant in the registry, as a module-level attribute,
        # and in the module's public export list.
        TASKS[variant.name] = variant
        globals()[variant.name] = variant
        __all__.append(variant.name)

# Legacy aliases for compatibility: each points at the first variant of its tier.
TASK_EASY, TASK_MEDIUM, TASK_HARD = (
    TASKS["task1"],
    TASKS["task11"],
    TASKS["task21"],
)
def get_task(name: str) -> TaskConfig:
    """Look up a task configuration by name.

    The name is lower-cased and stripped of surrounding whitespace, and
    underscores are removed before it is matched against ``TASKS``. The
    legacy names ``easy``, ``medium`` and ``hard`` resolve to ``task1``,
    ``task11`` and ``task21`` respectively.

    Raises:
        ValueError: if the name matches neither a registered task nor a
            legacy alias.
    """
    cleaned = name.lower().strip()
    compact = cleaned.replace("_", "")

    # Direct hit on a registered task ID (task1, task2, ...).
    try:
        return TASKS[compact]
    except KeyError:
        pass

    # Map legacy names to specific subtasks; anything else falls back to the
    # cleaned name with its underscores intact.
    aliases = {"easy": "task1", "medium": "task11", "hard": "task21"}
    resolved = aliases.get(compact, cleaned)
    if resolved in TASKS:
        return TASKS[resolved]

    raise ValueError(f"Unknown task '{name}'. Choose from: {list(TASKS.keys())}")