atin5551's picture
Deploy Varaha OpenEnv Docker Space
cb70a7d
"""Varaha simulation types — core data structures for the wildfire logistics environment."""
import math
from dataclasses import dataclass, field
from typing import Any
# ---------------------------------------------------------------------------
# Vec3
# ---------------------------------------------------------------------------
@dataclass
class Vec3:
    """Lightweight 3-component vector with basic arithmetic helpers.

    Every helper is non-mutating: it returns a new ``Vec3`` (or a plain
    float) and leaves both operands untouched.
    """

    x: float = 0.0
    y: float = 0.0
    z: float = 0.0

    # --- arithmetic ---

    def __add__(self, other: "Vec3") -> "Vec3":
        """Return the component-wise sum ``self + other``."""
        return Vec3(self.x + other.x, self.y + other.y, self.z + other.z)

    def __sub__(self, other: "Vec3") -> "Vec3":
        """Return the component-wise difference ``self - other``."""
        return Vec3(self.x - other.x, self.y - other.y, self.z - other.z)

    def scale(self, s: float) -> "Vec3":
        """Return a copy with every component multiplied by ``s``."""
        return Vec3(self.x * s, self.y * s, self.z * s)

    # --- magnitude ---

    def norm(self) -> float:
        """Return the Euclidean length of the vector."""
        return math.sqrt(self.x ** 2 + self.y ** 2 + self.z ** 2)

    def normalized(self) -> "Vec3":
        """Return a unit-length copy, or the zero vector if the length is below 1e-9."""
        n = self.norm()
        if n < 1e-9:
            # Avoid dividing by (almost) zero for degenerate vectors.
            return Vec3(0.0, 0.0, 0.0)
        return self.scale(1.0 / n)

    def clamp_magnitude(self, max_mag: float) -> "Vec3":
        """Return a copy whose length does not exceed ``max_mag``.

        The direction is preserved. A negative ``max_mag`` is treated as
        ``0`` so that the result is the zero vector instead of a vector
        pointing the opposite way (scaling by a negative factor would
        silently reverse it).
        """
        limit = max(max_mag, 0.0)
        n = self.norm()
        if n > limit and n > 1e-9:
            return self.scale(limit / n)
        # Already within the limit (or effectively zero): return a copy.
        return Vec3(self.x, self.y, self.z)

    # --- distance ---

    def distance_to(self, other: "Vec3") -> float:
        """Return the 3D Euclidean distance to ``other``."""
        return (self - other).norm()

    def horizontal_distance_to(self, other: "Vec3") -> float:
        """Return the distance to ``other`` measured in the x/y plane only."""
        dx = self.x - other.x
        dy = self.y - other.y
        return math.sqrt(dx * dx + dy * dy)

    # --- serialization ---

    def to_dict(self) -> dict[str, float]:
        """Return ``{"x", "y", "z"}`` with each component rounded to 4 decimals."""
        return {"x": round(self.x, 4), "y": round(self.y, 4), "z": round(self.z, 4)}

    def __repr__(self) -> str:
        """Return a compact representation with 2 decimals per component."""
        return f"Vec3({self.x:.2f}, {self.y:.2f}, {self.z:.2f})"
# ---------------------------------------------------------------------------
# Drone
# ---------------------------------------------------------------------------
@dataclass
class DroneState:
    """Kinematic and status snapshot of the drone.

    Holds where the drone is, how fast it is moving, and its mission
    status flags. ``to_dict`` exports the same information as built-in
    types.
    """

    position: Vec3 = field(default_factory=Vec3)
    velocity: Vec3 = field(default_factory=Vec3)
    battery: float = 100.0  # starts at 100.0 (presumably a percentage — confirm)
    carrying_payload: bool = True
    alive: bool = True

    def to_dict(self) -> dict[str, Any]:
        """Return the state as a plain dict; battery is rounded to 4 decimals."""
        snapshot: dict[str, Any] = {}
        snapshot["position"] = self.position.to_dict()
        snapshot["velocity"] = self.velocity.to_dict()
        snapshot["battery"] = round(self.battery, 4)
        snapshot["carrying_payload"] = self.carrying_payload
        snapshot["alive"] = self.alive
        return snapshot
# ---------------------------------------------------------------------------
# World entities
# ---------------------------------------------------------------------------
@dataclass
class BaseStation:
    """Home base where the drone launches, lands, and recharges.

    ``recharge_radius`` is stored and exported here; how it is applied
    is decided by the environment code, not by this class.
    """

    position: Vec3 = field(default_factory=Vec3)
    recharge_radius: float = 20.0

    def to_dict(self) -> dict[str, Any]:
        """Return the station as a plain dict (radius is not rounded)."""
        location = self.position.to_dict()
        return dict(position=location, recharge_radius=self.recharge_radius)
@dataclass
class DeliveryTarget:
    """A responder zone requiring supply delivery.

    ``delivered`` is a mutable flag; nothing in this class sets it, so
    it is expected to be flipped by the environment.
    """

    id: str = ""
    position: Vec3 = field(default_factory=Vec3)
    urgency: float = 0.5
    delivered: bool = False
    delivery_radius: float = 15.0

    def to_dict(self) -> dict[str, Any]:
        """Return the target as a plain dict; urgency is rounded to 4 decimals."""
        exported: dict[str, Any] = {"id": self.id}
        exported["position"] = self.position.to_dict()
        exported["urgency"] = round(self.urgency, 4)
        exported["delivered"] = self.delivered
        exported["delivery_radius"] = self.delivery_radius
        return exported
@dataclass
class HazardRegion:
    """Wildfire danger zone with a circular footprint and a limited height.

    The zone extends ``radius`` horizontally around ``center`` and from
    ``center.z`` upward to the current height. Danger is zero outside
    that volume (including above it, so a drone can overfly a fire at
    sufficient altitude) and increases linearly toward the ground-level
    center, both horizontally and vertically.

    ``growth_rate`` is the per-step height increase (metres/step) that
    simulates fire growth over an episode; ``height`` is the initial
    height restored by ``reset``.
    """

    id: str = ""
    center: Vec3 = field(default_factory=Vec3)
    radius: float = 50.0
    severity: float = 0.5
    height: float = 80.0
    growth_rate: float = 0.0
    # Dynamic height; excluded from __init__ and repr, seeded in __post_init__.
    _current_height: float = field(default=0.0, init=False, repr=False)

    def __post_init__(self):
        """Start the dynamic height at the configured initial height."""
        self._current_height = self.height

    def reset(self):
        """Restore the dynamic height to ``height`` for a new episode."""
        self._current_height = self.height

    def tick(self):
        """Advance one timestep; the fire grows only for a positive growth rate."""
        if self.growth_rate > 0:
            self._current_height += self.growth_rate

    def _relative_offsets(self, pos: Vec3) -> tuple[float, float]:
        """Return (horizontal distance from center, altitude above ``center.z``)."""
        planar = ((pos.x - self.center.x) ** 2 + (pos.y - self.center.y) ** 2) ** 0.5
        altitude = pos.z - self.center.z
        return planar, altitude

    def contains(self, pos: Vec3) -> bool:
        """Return True if ``pos`` is inside the hazard volume.

        The footprint test is inclusive (``<= radius``); the vertical test
        includes the base and excludes the top (``0 <= alt < height``).
        """
        planar, altitude = self._relative_offsets(pos)
        inside_footprint = planar <= self.radius
        inside_column = 0 <= altitude < self._current_height
        return inside_footprint and inside_column

    def danger_factor(self, pos: Vec3) -> float:
        """Return the danger at ``pos``: 0 outside, up to ``severity`` at the base center."""
        planar, altitude = self._relative_offsets(pos)
        if planar >= self.radius:
            return 0.0
        ceiling = self._current_height
        if altitude >= ceiling or altitude < 0:
            return 0.0
        # Linear falloff: 1 at the center/base, 0 at the rim/top.
        horizontal_falloff = 1.0 - planar / self.radius
        vertical_falloff = 1.0 - altitude / ceiling
        return self.severity * horizontal_falloff * vertical_falloff

    def to_dict(self) -> dict[str, Any]:
        """Return the region as a plain dict; only ``current_height`` is rounded (2 decimals)."""
        exported: dict[str, Any] = {"id": self.id}
        exported["center"] = self.center.to_dict()
        exported["radius"] = self.radius
        exported["severity"] = self.severity
        exported["height"] = self.height
        exported["current_height"] = round(self._current_height, 2)
        exported["growth_rate"] = self.growth_rate
        return exported
@dataclass
class ObstacleVolume:
    """Axis-aligned 3D box that the drone must not enter.

    The box is described by its two opposite corners; ``kind`` is a
    free-form label that defaults to ``"building"``.
    """

    id: str = ""
    min_corner: Vec3 = field(default_factory=Vec3)
    max_corner: Vec3 = field(default_factory=Vec3)
    kind: str = "building"

    def contains(self, pos: Vec3) -> bool:
        """Return True if ``pos`` lies inside the box or on its boundary."""
        lo, hi = self.min_corner, self.max_corner
        return (
            lo.x <= pos.x <= hi.x
            and lo.y <= pos.y <= hi.y
            and lo.z <= pos.z <= hi.z
        )

    @property
    def center(self) -> Vec3:
        """Geometric center of the box (midpoint of the two corners)."""
        return Vec3(
            (self.min_corner.x + self.max_corner.x) / 2,
            (self.min_corner.y + self.max_corner.y) / 2,
            (self.min_corner.z + self.max_corner.z) / 2,
        )

    @property
    def half_size(self) -> Vec3:
        """Half of the box extent along each axis."""
        return Vec3(
            (self.max_corner.x - self.min_corner.x) / 2,
            (self.max_corner.y - self.min_corner.y) / 2,
            (self.max_corner.z - self.min_corner.z) / 2,
        )

    @property
    def height(self) -> float:
        """Altitude of the top face (``max_corner.z``), not the vertical extent."""
        return self.max_corner.z

    def nearest_surface_dist(self, pos: Vec3) -> float:
        """Return the Euclidean distance from ``pos`` to the box.

        The result is never negative: it is ``0.0`` for any point inside
        the box or on its surface, matching ``CylindricalObstacle``. This
        is *not* a signed distance — penetration depth is not reported.
        """
        # Evaluate each property once; every access builds a new Vec3.
        mid = self.center
        half = self.half_size
        dx = max(abs(pos.x - mid.x) - half.x, 0.0)
        dy = max(abs(pos.y - mid.y) - half.y, 0.0)
        dz_below = max(self.min_corner.z - pos.z, 0.0)
        dz_above = max(pos.z - self.max_corner.z, 0.0)
        return math.sqrt(dx * dx + dy * dy + (dz_below + dz_above) ** 2)

    def to_dict(self) -> dict[str, Any]:
        """Return the obstacle as a plain dict of built-in types."""
        return {
            "id": self.id,
            "min_corner": self.min_corner.to_dict(),
            "max_corner": self.max_corner.to_dict(),
            "kind": self.kind,
        }
@dataclass
class CylindricalObstacle:
    """Vertical cylinder obstacle — trees, poles, pillars, tanks.

    The cylinder axis passes through ``center`` in the x/y plane. The
    vertical tests use the absolute range ``0 .. height``; ``center.z``
    is not consulted by the geometry methods.
    """

    id: str = ""
    center: Vec3 = field(default_factory=Vec3)
    radius: float = 10.0
    height: float = 50.0
    kind: str = "tree"

    def _axis_distance(self, pos: Vec3) -> float:
        """Return the horizontal distance from ``pos`` to the cylinder axis."""
        off_x = pos.x - self.center.x
        off_y = pos.y - self.center.y
        return math.sqrt(off_x * off_x + off_y * off_y)

    def contains(self, pos: Vec3) -> bool:
        """Return True if ``pos`` is inside the cylinder or on its boundary."""
        within_radius = self._axis_distance(pos) <= self.radius
        within_height = 0 <= pos.z <= self.height
        return within_radius and within_height

    def nearest_surface_dist(self, pos: Vec3) -> float:
        """Return the distance from ``pos`` to the cylinder (0.0 when inside)."""
        radial_gap = max(self._axis_distance(pos) - self.radius, 0.0)
        if pos.z > self.height:
            # Above the top cap: the gap is strictly positive here.
            vert_gap = pos.z - self.height
        else:
            # At or below the top: only a point under z = 0 leaves a gap.
            vert_gap = max(-pos.z, 0.0)
        return math.sqrt(radial_gap ** 2 + vert_gap ** 2)

    def to_dict(self) -> dict[str, Any]:
        """Return the obstacle as a plain dict; radius and height rounded to 2 decimals."""
        exported: dict[str, Any] = {"id": self.id}
        exported["center"] = self.center.to_dict()
        exported["radius"] = round(self.radius, 2)
        exported["height"] = round(self.height, 2)
        exported["kind"] = self.kind
        return exported
# ---------------------------------------------------------------------------
# Responder units — dynamic actors that alter mission conditions mid-episode
# ---------------------------------------------------------------------------
# Responder status name -> numeric code (0.0 for stable up to 1.0 for critical).
RESPONDER_STATUS_MAP = {"stable": 0.0, "urgent": 0.5, "critical": 1.0}
# Valid status names, in the same order as the map above.
RESPONDER_STATUSES = tuple(RESPONDER_STATUS_MAP)

# (x, y) direction per intel type: +y is north and +x is east. Intel types
# that carry no direction map to (0.0, 0.0).
_NO_DIRECTION = (0.0, 0.0)
_NORTH = (0.0, 1.0)
_SOUTH = (0.0, -1.0)
_EAST = (1.0, 0.0)
_WEST = (-1.0, 0.0)
INTEL_DIRECTION_VECS = {
    "none": _NO_DIRECTION,
    "blocked_north": _NORTH,
    "blocked_south": _SOUTH,
    "blocked_east": _EAST,
    "blocked_west": _WEST,
    "safe_north": _NORTH,
    "safe_south": _SOUTH,
    "safe_east": _EAST,
    "safe_west": _WEST,
    "fire_expanded": _NO_DIRECTION,
    "fire_receded": _NO_DIRECTION,
}
# All recognised intel type names, in the same order as the mapping above.
INTEL_TYPES = tuple(INTEL_DIRECTION_VECS)
@dataclass
class ScheduledEvent:
    """A future event a responder will trigger at a specific step.

    This is a passive record: it has no behaviour of its own, and the
    code that fires events is expected to update ``fired``.
    """

    # Step at which the event is due.
    step: int = 0
    # Name of the event; the accepted values are defined by the consumer.
    event_type: str = ""
    # Event-specific data; each instance gets its own dict.
    payload: dict[str, Any] = field(default_factory=dict)
    # Whether the event has already been triggered.
    fired: bool = False
@dataclass
class ResponderUnit:
    """First responder on the ground linked to a delivery target.

    A responder can alter mission conditions mid-episode:
    1. Update the urgency of its linked target
    2. Relocate the drop-zone (move the target position)
    3. Broadcast hazard intel (structured approach guidance)

    This class only stores the responder's data and exposes numeric
    encodings of it; applying the changes is left to the environment.
    """

    id: str = ""
    position: Vec3 = field(default_factory=Vec3)
    linked_target_id: str = ""
    status: str = "stable"
    current_need: str = "supplies"
    message: str = ""
    can_update_dropzone: bool = False
    active: bool = True
    latest_intel: str = "none"
    intel_severity: float = 0.0
    # Not included in ``to_dict``.
    scheduled_events: list[ScheduledEvent] = field(default_factory=list)

    def status_code(self) -> float:
        """Return the numeric code for ``status``; unknown statuses map to 0.0."""
        fallback = 0.0
        return RESPONDER_STATUS_MAP.get(self.status, fallback)

    def intel_direction(self) -> tuple[float, float]:
        """Return the (x, y) direction for ``latest_intel``; (0.0, 0.0) if unknown."""
        fallback = (0.0, 0.0)
        return INTEL_DIRECTION_VECS.get(self.latest_intel, fallback)

    def to_dict(self) -> dict[str, Any]:
        """Return the responder as a plain dict, without ``scheduled_events``."""
        exported: dict[str, Any] = {
            "id": self.id,
            "position": self.position.to_dict(),
        }
        # These fields are exported unchanged, in declaration order.
        passthrough = (
            "linked_target_id",
            "status",
            "current_need",
            "message",
            "can_update_dropzone",
            "active",
            "latest_intel",
        )
        for attr in passthrough:
            exported[attr] = getattr(self, attr)
        exported["intel_severity"] = round(self.intel_severity, 4)
        return exported
# ---------------------------------------------------------------------------
# Observation & step diagnostics
# ---------------------------------------------------------------------------
@dataclass
class VarahaObservation:
    """Structured observation returned to the agent each step.

    Kept as a dataclass for documentation; the env also offers a plain-dict
    path via ``get_observation()`` for maximum serialisation flexibility.
    """

    drone_position: Vec3 = field(default_factory=Vec3)
    drone_velocity: Vec3 = field(default_factory=Vec3)
    battery: float = 100.0
    carrying_payload: bool = True
    alive: bool = True
    # Target descriptions stored as plain dicts.
    targets: list[dict[str, Any]] = field(default_factory=list)
    step: int = 0
    max_steps: int = 500

    def to_dict(self) -> dict[str, Any]:
        """Return the observation as a plain dict.

        Battery is rounded to 4 decimals. ``targets`` is the same list
        object held by the instance, not a copy.
        """
        exported: dict[str, Any] = {}
        exported["drone_position"] = self.drone_position.to_dict()
        exported["drone_velocity"] = self.drone_velocity.to_dict()
        exported["battery"] = round(self.battery, 4)
        exported["carrying_payload"] = self.carrying_payload
        exported["alive"] = self.alive
        exported["targets"] = self.targets
        exported["step"] = self.step
        exported["max_steps"] = self.max_steps
        return exported
@dataclass
class MissionInstruction:
    """Single mission instruction used for long-horizon planning mode.

    ``completed`` and ``violated`` are progress flags; this class only
    stores and exports them.
    """

    id: str = ""
    kind: str = ""
    description: str = ""
    target_id: str = ""
    tool_name: str = ""
    completed: bool = False
    violated: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Return every field unchanged, keyed by name, in declaration order."""
        exported_fields = (
            "id",
            "kind",
            "description",
            "target_id",
            "tool_name",
            "completed",
            "violated",
        )
        return {name: getattr(self, name) for name in exported_fields}
@dataclass
class TracePoint:
    """Single frame of the drone's recorded trajectory."""

    step: int = 0
    position: Vec3 = field(default_factory=Vec3)
    velocity: Vec3 = field(default_factory=Vec3)
    battery: float = 100.0
    reward: float = 0.0
    cumulative_reward: float = 0.0
    events: list[str] = field(default_factory=list)
    observation: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return the frame as a plain dict.

        Battery and both reward values are rounded to 4 decimals.
        ``events`` is copied into a new list, while ``observation`` is
        the same dict object held by the instance.
        """
        exported: dict[str, Any] = {"step": self.step}
        exported["position"] = self.position.to_dict()
        exported["velocity"] = self.velocity.to_dict()
        for attr in ("battery", "reward", "cumulative_reward"):
            exported[attr] = round(getattr(self, attr), 4)
        exported["events"] = list(self.events)
        exported["observation"] = self.observation
        return exported
@dataclass
class StepInfo:
    """Per-step diagnostic info returned alongside the reward."""

    collision: bool = False
    delivered_target_ids: list[str] = field(default_factory=list)
    in_hazard: bool = False
    hazard_severity: float = 0.0
    reached_base: bool = False
    distance_traveled: float = 0.0
    tool_call: str = ""
    tool_result: dict[str, Any] = field(default_factory=dict)
    instruction_completed: int = 0
    instruction_total: int = 0
    instruction_violations: int = 0
    reward_breakdown: dict[str, float] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return the diagnostics as a plain dict.

        Float values (including each reward component) are rounded to 4
        decimals. ``delivered_target_ids`` is copied; ``tool_result`` is
        the same dict object held by the instance.
        """
        rounded_rewards = {}
        for component, value in self.reward_breakdown.items():
            rounded_rewards[component] = round(value, 4)

        exported: dict[str, Any] = {}
        exported["collision"] = self.collision
        exported["delivered_target_ids"] = list(self.delivered_target_ids)
        exported["in_hazard"] = self.in_hazard
        exported["hazard_severity"] = round(self.hazard_severity, 4)
        exported["reached_base"] = self.reached_base
        exported["distance_traveled"] = round(self.distance_traveled, 4)
        exported["tool_call"] = self.tool_call
        exported["tool_result"] = self.tool_result
        exported["instruction_completed"] = self.instruction_completed
        exported["instruction_total"] = self.instruction_total
        exported["instruction_violations"] = self.instruction_violations
        exported["reward_breakdown"] = rounded_rewards
        return exported