Spaces:
Runtime error
Runtime error
| """Varaha simulation types — core data structures for the wildfire logistics environment.""" | |
| import math | |
| from dataclasses import dataclass, field | |
| from typing import Any | |
| # --------------------------------------------------------------------------- | |
| # Vec3 | |
| # --------------------------------------------------------------------------- | |
@dataclass
class Vec3:
    """Lightweight 3-component vector with basic arithmetic helpers.

    Declared as a dataclass so the annotated ``x``/``y``/``z`` defaults become
    constructor arguments; every helper below builds its result with
    ``Vec3(x, y, z)`` and returns a new instance instead of mutating ``self``.
    """

    x: float = 0.0
    y: float = 0.0
    z: float = 0.0

    # --- arithmetic ---
    def __add__(self, other: "Vec3") -> "Vec3":
        """Return the component-wise sum of ``self`` and ``other``."""
        return Vec3(self.x + other.x, self.y + other.y, self.z + other.z)

    def __sub__(self, other: "Vec3") -> "Vec3":
        """Return the component-wise difference ``self - other``."""
        return Vec3(self.x - other.x, self.y - other.y, self.z - other.z)

    def scale(self, s: float) -> "Vec3":
        """Return a copy with every component multiplied by ``s``."""
        return Vec3(self.x * s, self.y * s, self.z * s)

    # --- magnitude ---
    def norm(self) -> float:
        """Return the Euclidean length of the vector."""
        return math.sqrt(self.x ** 2 + self.y ** 2 + self.z ** 2)

    def normalized(self) -> "Vec3":
        """Return a unit-length copy, or the zero vector if the length is below 1e-9."""
        n = self.norm()
        if n < 1e-9:
            # Avoid dividing by a (near-)zero length.
            return Vec3(0.0, 0.0, 0.0)
        return self.scale(1.0 / n)

    def clamp_magnitude(self, max_mag: float) -> "Vec3":
        """Return a copy whose length is at most ``max_mag``, keeping direction."""
        n = self.norm()
        if n > max_mag and n > 1e-9:
            return self.scale(max_mag / n)
        # Already within the limit (or effectively zero): return an unscaled copy.
        return Vec3(self.x, self.y, self.z)

    # --- distance ---
    def distance_to(self, other: "Vec3") -> float:
        """Return the 3D Euclidean distance to ``other``."""
        return (self - other).norm()

    def horizontal_distance_to(self, other: "Vec3") -> float:
        """Return the distance to ``other`` in the x/y plane, ignoring z."""
        dx = self.x - other.x
        dy = self.y - other.y
        return math.sqrt(dx * dx + dy * dy)

    # --- serialization ---
    def to_dict(self) -> dict[str, float]:
        """Return the components as a dict, each rounded to 4 decimal places."""
        return {"x": round(self.x, 4), "y": round(self.y, 4), "z": round(self.z, 4)}

    def __repr__(self) -> str:
        # Explicit __repr__ takes precedence over the dataclass-generated one.
        return f"Vec3({self.x:.2f}, {self.y:.2f}, {self.z:.2f})"
| # --------------------------------------------------------------------------- | |
| # Drone | |
| # --------------------------------------------------------------------------- | |
@dataclass
class DroneState:
    """Full kinematic + status state of the drone.

    Declared as a dataclass so the ``field(default_factory=...)`` defaults
    give each instance its own ``Vec3`` objects.
    """

    position: Vec3 = field(default_factory=Vec3)
    velocity: Vec3 = field(default_factory=Vec3)
    battery: float = 100.0
    carrying_payload: bool = True
    alive: bool = True

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot; vectors are nested dicts, battery is rounded to 4 places."""
        return {
            "position": self.position.to_dict(),
            "velocity": self.velocity.to_dict(),
            "battery": round(self.battery, 4),
            "carrying_payload": self.carrying_payload,
            "alive": self.alive,
        }
| # --------------------------------------------------------------------------- | |
| # World entities | |
| # --------------------------------------------------------------------------- | |
@dataclass
class BaseStation:
    """Home base where the drone launches, lands, and recharges.

    Declared as a dataclass so ``position`` defaults to a fresh ``Vec3`` per
    instance and both fields can be passed to the constructor.
    """

    position: Vec3 = field(default_factory=Vec3)
    recharge_radius: float = 20.0

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot of the base station."""
        return {
            "position": self.position.to_dict(),
            "recharge_radius": self.recharge_radius,
        }
@dataclass
class DeliveryTarget:
    """A responder zone requiring supply delivery.

    Declared as a dataclass so the annotated defaults become constructor
    arguments and ``position`` gets its own ``Vec3`` per instance.
    """

    id: str = ""
    position: Vec3 = field(default_factory=Vec3)
    urgency: float = 0.5
    delivered: bool = False
    delivery_radius: float = 15.0

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot; ``urgency`` is rounded to 4 decimal places."""
        return {
            "id": self.id,
            "position": self.position.to_dict(),
            "urgency": round(self.urgency, 4),
            "delivered": self.delivered,
            "delivery_radius": self.delivery_radius,
        }
@dataclass
class HazardRegion:
    """Wildfire danger zone modeled as a ground-level dome.

    The hazard has a horizontal radius and a height. Danger is zero
    above the current height and outside ``radius``, allowing drones to fly
    over fires at sufficient altitude. Within the dome, danger scales
    with proximity to the center both horizontally and vertically.

    ``growth_rate`` controls per-step height increase (metres/step),
    simulating fire growth over an episode. ``height`` is the initial
    height; the growing value is tracked separately in ``_current_height``
    so that ``reset()`` can restore it.
    """

    id: str = ""
    center: Vec3 = field(default_factory=Vec3)
    radius: float = 50.0
    severity: float = 0.5
    height: float = 80.0
    growth_rate: float = 0.0
    # Dynamic height; excluded from __init__ and repr, seeded in __post_init__.
    _current_height: float = field(default=0.0, init=False, repr=False)

    def __post_init__(self):
        self._current_height = self.height

    def reset(self):
        """Reset dynamic state for a new episode."""
        self._current_height = self.height

    def tick(self):
        """Advance one timestep — grow the fire."""
        if self.growth_rate > 0:
            self._current_height += self.growth_rate

    def contains(self, pos: Vec3) -> bool:
        """Return True if ``pos`` is within ``radius`` horizontally and within
        ``[0, current height)`` above ``center.z``."""
        horiz = ((pos.x - self.center.x) ** 2 + (pos.y - self.center.y) ** 2) ** 0.5
        alt = pos.z - self.center.z
        return horiz <= self.radius and 0 <= alt < self._current_height

    def danger_factor(self, pos: Vec3) -> float:
        """0 outside the dome, scales up toward the ground-level center.

        Inside, the result is ``severity`` times a horizontal and a vertical
        falloff, each 1.0 at the center/ground and approaching 0.0 at the
        dome boundary.
        """
        horiz = ((pos.x - self.center.x) ** 2 + (pos.y - self.center.y) ** 2) ** 0.5
        # Also protects the division below: horiz >= 0, so a radius <= 0 returns here.
        if horiz >= self.radius:
            return 0.0
        alt = pos.z - self.center.z
        # Also protects the division below: a current height <= 0 always returns here.
        if alt >= self._current_height or alt < 0:
            return 0.0
        horiz_factor = 1.0 - horiz / self.radius
        vert_factor = 1.0 - alt / self._current_height
        return self.severity * horiz_factor * vert_factor

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot, including the current (grown) height
        rounded to 2 decimal places."""
        return {
            "id": self.id,
            "center": self.center.to_dict(),
            "radius": self.radius,
            "severity": self.severity,
            "height": self.height,
            "current_height": round(self._current_height, 2),
            "growth_rate": self.growth_rate,
        }
@dataclass
class ObstacleVolume:
    """Axis-aligned 3D box that the drone must not enter.

    ``center``, ``half_size`` and ``height`` are read-only properties derived
    from the two corners; ``nearest_surface_dist`` reads ``center`` and
    ``half_size`` as attributes.
    """

    id: str = ""
    min_corner: Vec3 = field(default_factory=Vec3)
    max_corner: Vec3 = field(default_factory=Vec3)
    kind: str = "building"

    def contains(self, pos: Vec3) -> bool:
        """Return True if ``pos`` lies inside the box, boundaries included."""
        return (
            self.min_corner.x <= pos.x <= self.max_corner.x
            and self.min_corner.y <= pos.y <= self.max_corner.y
            and self.min_corner.z <= pos.z <= self.max_corner.z
        )

    @property
    def center(self) -> Vec3:
        """Midpoint between ``min_corner`` and ``max_corner``."""
        return Vec3(
            (self.min_corner.x + self.max_corner.x) / 2,
            (self.min_corner.y + self.max_corner.y) / 2,
            (self.min_corner.z + self.max_corner.z) / 2,
        )

    @property
    def half_size(self) -> Vec3:
        """Half of the box extent along each axis."""
        return Vec3(
            (self.max_corner.x - self.min_corner.x) / 2,
            (self.max_corner.y - self.min_corner.y) / 2,
            (self.max_corner.z - self.min_corner.z) / 2,
        )

    @property
    def height(self) -> float:
        """Z coordinate of the top face (``max_corner.z``)."""
        return self.max_corner.z

    def nearest_surface_dist(self, pos: Vec3) -> float:
        """Return the Euclidean distance from ``pos`` to the box.

        The result is never negative: each per-axis gap is clamped at zero,
        so a point inside or on the box yields 0.0.
        """
        cx, cy = self.center.x, self.center.y
        hx, hy = self.half_size.x, self.half_size.y
        dx = max(abs(pos.x - cx) - hx, 0.0)
        dy = max(abs(pos.y - cy) - hy, 0.0)
        # With min_corner.z <= max_corner.z at most one of these is non-zero.
        dz_below = max(self.min_corner.z - pos.z, 0.0)
        dz_above = max(pos.z - self.max_corner.z, 0.0)
        return math.sqrt(dx * dx + dy * dy + (dz_below + dz_above) ** 2)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot of the stored fields."""
        return {
            "id": self.id,
            "min_corner": self.min_corner.to_dict(),
            "max_corner": self.max_corner.to_dict(),
            "kind": self.kind,
        }
@dataclass
class CylindricalObstacle:
    """Vertical cylinder obstacle — trees, poles, pillars, tanks.

    Only ``center.x``/``center.y`` are used for geometry; the vertical extent
    tested by the methods below is ``0 <= z <= height`` and does not read
    ``center.z``.
    """

    id: str = ""
    center: Vec3 = field(default_factory=Vec3)
    radius: float = 10.0
    height: float = 50.0
    kind: str = "tree"

    def contains(self, pos: Vec3) -> bool:
        """Return True if ``pos`` is within ``radius`` of the axis and between
        z = 0 and z = ``height`` (both inclusive)."""
        dx = pos.x - self.center.x
        dy = pos.y - self.center.y
        horiz_dist = math.sqrt(dx * dx + dy * dy)
        return horiz_dist <= self.radius and 0 <= pos.z <= self.height

    def nearest_surface_dist(self, pos: Vec3) -> float:
        """Return the distance from ``pos`` to the cylinder; 0.0 when inside.

        Combines the radial gap beyond ``radius`` with the vertical gap above
        ``height`` or below z = 0.
        """
        dx = pos.x - self.center.x
        dy = pos.y - self.center.y
        horiz_dist = math.sqrt(dx * dx + dy * dy)
        radial_gap = max(horiz_dist - self.radius, 0.0)
        # Above the top: gap to the top; otherwise gap below z = 0 (zero if within).
        vert_gap = max(pos.z - self.height, 0.0) if pos.z > self.height else max(-pos.z, 0.0)
        return math.sqrt(radial_gap ** 2 + vert_gap ** 2)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot; ``radius`` and ``height`` are rounded to 2 places."""
        return {
            "id": self.id,
            "center": self.center.to_dict(),
            "radius": round(self.radius, 2),
            "height": round(self.height, 2),
            "kind": self.kind,
        }
| # --------------------------------------------------------------------------- | |
| # Responder units — dynamic actors that alter mission conditions mid-episode | |
| # --------------------------------------------------------------------------- | |
# Responder status label -> numeric code, in increasing order of severity.
RESPONDER_STATUS_MAP = {
    "stable": 0.0,
    "urgent": 0.5,
    "critical": 1.0,
}
# Status labels, in the same order as the map above.
RESPONDER_STATUSES = tuple(RESPONDER_STATUS_MAP)

# Unit (x, y) vector for each compass direction used by the intel messages.
_COMPASS_VECS = {
    "north": (0.0, 1.0),
    "south": (0.0, -1.0),
    "east": (1.0, 0.0),
    "west": (-1.0, 0.0),
}
_NO_DIRECTION = (0.0, 0.0)

# Intel label -> (x, y) direction. "blocked_*" and "safe_*" carry the compass
# vector of their suffix; the remaining labels have no direction.
INTEL_DIRECTION_VECS = {
    "none": _NO_DIRECTION,
    **{f"blocked_{name}": vec for name, vec in _COMPASS_VECS.items()},
    **{f"safe_{name}": vec for name, vec in _COMPASS_VECS.items()},
    "fire_expanded": _NO_DIRECTION,
    "fire_receded": _NO_DIRECTION,
}
# Every intel label, in the insertion order of the table above.
INTEL_TYPES = tuple(INTEL_DIRECTION_VECS)
@dataclass
class ScheduledEvent:
    """A future event a responder will trigger at a specific step.

    Declared as a dataclass so each instance gets its own ``payload`` dict
    and all fields can be passed to the constructor.
    """

    step: int = 0
    event_type: str = ""
    payload: dict[str, Any] = field(default_factory=dict)
    fired: bool = False
@dataclass
class ResponderUnit:
    """First responder on the ground linked to a delivery target.

    Can dynamically alter mission conditions mid-episode:
    1. Update urgency of their linked target
    2. Relocate the drop-zone (move target position)
    3. Broadcast hazard intel (structured approach guidance)

    Declared as a dataclass so ``position`` and ``scheduled_events`` are
    created per instance rather than shared.
    """

    id: str = ""
    position: Vec3 = field(default_factory=Vec3)
    linked_target_id: str = ""
    status: str = "stable"
    current_need: str = "supplies"
    message: str = ""
    can_update_dropzone: bool = False
    active: bool = True
    latest_intel: str = "none"
    intel_severity: float = 0.0
    scheduled_events: list[ScheduledEvent] = field(default_factory=list)

    # NOTE(review): the two accessors below are plain methods here; confirm
    # whether callers expect them to be invoked or read as attributes.
    def status_code(self) -> float:
        """Return the numeric code for ``status``; 0.0 for an unknown status."""
        return RESPONDER_STATUS_MAP.get(self.status, 0.0)

    def intel_direction(self) -> tuple[float, float]:
        """Return the (x, y) vector for ``latest_intel``; (0.0, 0.0) if unknown."""
        return INTEL_DIRECTION_VECS.get(self.latest_intel, (0.0, 0.0))

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot; ``scheduled_events`` is not included."""
        return {
            "id": self.id,
            "position": self.position.to_dict(),
            "linked_target_id": self.linked_target_id,
            "status": self.status,
            "current_need": self.current_need,
            "message": self.message,
            "can_update_dropzone": self.can_update_dropzone,
            "active": self.active,
            "latest_intel": self.latest_intel,
            "intel_severity": round(self.intel_severity, 4),
        }
| # --------------------------------------------------------------------------- | |
| # Observation & step diagnostics | |
| # --------------------------------------------------------------------------- | |
@dataclass
class VarahaObservation:
    """Structured observation returned to the agent each step.

    Kept as a dataclass for documentation; the env also offers a plain-dict
    path via ``get_observation()`` for maximum serialisation flexibility.
    """

    drone_position: Vec3 = field(default_factory=Vec3)
    drone_velocity: Vec3 = field(default_factory=Vec3)
    battery: float = 100.0
    carrying_payload: bool = True
    alive: bool = True
    targets: list[dict[str, Any]] = field(default_factory=list)
    step: int = 0
    max_steps: int = 500

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot; ``targets`` is passed through by reference."""
        return {
            "drone_position": self.drone_position.to_dict(),
            "drone_velocity": self.drone_velocity.to_dict(),
            "battery": round(self.battery, 4),
            "carrying_payload": self.carrying_payload,
            "alive": self.alive,
            "targets": self.targets,
            "step": self.step,
            "max_steps": self.max_steps,
        }
@dataclass
class MissionInstruction:
    """Single mission instruction used for long-horizon planning mode.

    Declared as a dataclass so the annotated defaults become constructor
    arguments.
    """

    id: str = ""
    kind: str = ""
    description: str = ""
    target_id: str = ""
    tool_name: str = ""
    completed: bool = False
    violated: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Return all fields as a plain dict."""
        return {
            "id": self.id,
            "kind": self.kind,
            "description": self.description,
            "target_id": self.target_id,
            "tool_name": self.tool_name,
            "completed": self.completed,
            "violated": self.violated,
        }
@dataclass
class TracePoint:
    """Single frame of the drone's recorded trajectory.

    Declared as a dataclass so the vector, list and dict defaults are created
    per instance rather than shared.
    """

    step: int = 0
    position: Vec3 = field(default_factory=Vec3)
    velocity: Vec3 = field(default_factory=Vec3)
    battery: float = 100.0
    reward: float = 0.0
    cumulative_reward: float = 0.0
    events: list[str] = field(default_factory=list)
    observation: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot of the frame.

        Floats are rounded to 4 decimal places and ``events`` is copied;
        ``observation`` is passed through by reference.
        """
        return {
            "step": self.step,
            "position": self.position.to_dict(),
            "velocity": self.velocity.to_dict(),
            "battery": round(self.battery, 4),
            "reward": round(self.reward, 4),
            "cumulative_reward": round(self.cumulative_reward, 4),
            "events": list(self.events),
            "observation": self.observation,
        }
@dataclass
class StepInfo:
    """Per-step diagnostic info returned alongside the reward.

    Declared as a dataclass so the list and dict defaults are created per
    instance rather than shared.
    """

    collision: bool = False
    delivered_target_ids: list[str] = field(default_factory=list)
    in_hazard: bool = False
    hazard_severity: float = 0.0
    reached_base: bool = False
    distance_traveled: float = 0.0
    tool_call: str = ""
    tool_result: dict[str, Any] = field(default_factory=dict)
    instruction_completed: int = 0
    instruction_total: int = 0
    instruction_violations: int = 0
    reward_breakdown: dict[str, float] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot of the diagnostics.

        Floats (including every ``reward_breakdown`` value) are rounded to 4
        decimal places and ``delivered_target_ids`` is copied; ``tool_result``
        is passed through by reference.
        """
        return {
            "collision": self.collision,
            "delivered_target_ids": list(self.delivered_target_ids),
            "in_hazard": self.in_hazard,
            "hazard_severity": round(self.hazard_severity, 4),
            "reached_base": self.reached_base,
            "distance_traveled": round(self.distance_traveled, 4),
            "tool_call": self.tool_call,
            "tool_result": self.tool_result,
            "instruction_completed": self.instruction_completed,
            "instruction_total": self.instruction_total,
            "instruction_violations": self.instruction_violations,
            "reward_breakdown": {
                k: round(v, 4) for k, v in self.reward_breakdown.items()
            },
        }