# Source: dm_control_env/server/drone_forest_environment.py
# Last commit (Leeps, 45f768e): "Add render error logging for debugging"
"""
Drone Navigation Environment (simplified — no trees).
A quadrotor drone flies to a nearby target in an open arena.
The RL policy commands velocity (forward/left/up/turn) while a built-in PD flight
controller handles low-level motor mixing.
"""
import base64
import io
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional
from uuid import uuid4
# Choose a MuJoCo rendering backend before `mujoco` is imported.
# macOS is left untouched; on every other platform "egl" is used unless the
# caller has already exported MUJOCO_GL (setdefault never overwrites it).
if sys.platform != "darwin":
    os.environ.setdefault("MUJOCO_GL", "egl")
import numpy as np
try:
from openenv.core.env_server.interfaces import Environment
from ..models import DMControlAction, DMControlObservation, DMControlState
except ImportError:
from openenv.core.env_server.interfaces import Environment
try:
import sys as _sys
from pathlib import Path as _Path
_parent = str(_Path(__file__).parent.parent)
if _parent not in _sys.path:
_sys.path.insert(0, _parent)
from models import DMControlAction, DMControlObservation, DMControlState
except ImportError:
try:
from dm_control_env.models import (
DMControlAction,
DMControlObservation,
DMControlState,
)
except ImportError:
from envs.dm_control_env.models import (
DMControlAction,
DMControlObservation,
DMControlState,
)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# --- Arena and episode limits ----------------------------------------------
ARENA_HALF = 10.0  # half side length in metres; the arena spans 20 m x 20 m
MAX_ALTITUDE = 8.0  # flying above this height ends the episode
MIN_ALTITUDE = 0.1  # dropping below this height ends the episode
MAX_STEPS = 1000  # control steps before the episode is cut off

# --- Target placement ------------------------------------------------------
TARGET_RADIUS = 0.5  # the target counts as reached inside this distance
TARGET_MIN_DIST = 2.0  # closest the target may be placed to the spawn point
TARGET_MAX_DIST = 4.0  # farthest the target may be placed from the spawn point

# --- Timing ----------------------------------------------------------------
PHYSICS_DT = 0.002  # physics timestep reported in the environment state
CONTROL_DT = 0.02  # one policy action every 20 ms (50 Hz control)

# --- Command scaling: a normalized command of 1.0 maps to these values -----
MAX_XY_VEL = 3.0  # m/s
MAX_Z_VEL = 2.0  # m/s
MAX_YAW_RATE = 2.0  # rad/s

# --- Flight-controller PD gains --------------------------------------------
KP_VEL = 4.0
KD_VEL = 1.5  # defined for completeness; not referenced by the controller here
KP_ATT = 8.0
KD_ATT = 2.0

# --- Drone physical parameters ---------------------------------------------
DRONE_MASS = 0.48  # body 0.4 + arms 0.08; meant to stay close to the XML
GRAVITY = 9.81
HOVER_THRUST = DRONE_MASS * GRAVITY / 4.0  # thrust each motor needs to hover
ARM_LENGTH = 0.14  # distance from centre of mass to each rotor

# The MuJoCo scene is expected to sit next to this module.
XML_PATH = str(Path(__file__).with_name("drone_forest.xml"))
class DroneForestEnvironment(Environment):
    """Drone navigates to a nearby target in an open arena.

    Each action is a 4-vector of normalized velocity commands
    ``[vx, vy, vz, yaw_rate]`` in ``[-1, 1]``; a built-in flight controller
    turns it into four motor thrusts. An episode ends when the drone comes
    within ``TARGET_RADIUS`` of the target, its body geom touches the ground,
    it leaves the arena or altitude bounds, or ``MAX_STEPS`` control steps
    have elapsed.
    """

    # Class-level flag; presumably read by the OpenEnv server to decide
    # whether several sessions may run at once - TODO confirm against the
    # Environment interface.
    SUPPORTS_CONCURRENT_SESSIONS = True
def __init__(
    self,
    render_height: Optional[int] = None,
    render_width: Optional[int] = None,
    **kwargs,
):
    """Set up episode bookkeeping; the MuJoCo model itself is loaded lazily.

    Args:
        render_height: Frame height in pixels. When falsy, the
            ``DMCONTROL_RENDER_HEIGHT`` environment variable is used
            (default ``"512"``).
        render_width: Frame width in pixels, resolved the same way from
            ``DMCONTROL_RENDER_WIDTH``.
        **kwargs: Accepted and ignored.
    """
    # Rendering configuration: an explicit argument wins, the environment
    # variable is only consulted as a fallback.
    if render_height:
        self._render_height = render_height
    else:
        self._render_height = int(os.environ.get("DMCONTROL_RENDER_HEIGHT", "512"))
    if render_width:
        self._render_width = render_width
    else:
        self._render_width = int(os.environ.get("DMCONTROL_RENDER_WIDTH", "512"))
    self._include_pixels = False

    # Simulation handles stay empty until _ensure_model() runs.
    self._model = None
    self._data = None

    # Per-episode bookkeeping.
    self._rng = np.random.RandomState()
    self._target_pos = np.zeros(3)
    self._prev_dist = None
    self._step_count = 0
    self._done = False

    # Public state record; the specs are filled in once the model is loaded.
    self._state = DMControlState(
        episode_id=str(uuid4()),
        step_count=0,
        domain_name="drone_forest",
        task_name="navigate",
    )
# ------------------------------------------------------------------
# Model loading
# ------------------------------------------------------------------
def _ensure_model(self):
    """Load the MuJoCo model once and cache the ids this class needs.

    Does nothing when a model is already loaded. Otherwise the XML at
    ``XML_PATH`` is compiled, the ids of the ``drone``/``target`` bodies and
    the ``drone_body``/``ground`` geoms are resolved, and the action and
    observation specs are published on ``self._state``.

    Raises:
        ValueError: If the XML does not define one of the required names.
    """
    if self._model is not None:
        return

    import mujoco

    # Build into locals first so a failed lookup below cannot leave the
    # instance holding a half-initialised model that later calls would
    # mistake for a fully loaded one.
    model = mujoco.MjModel.from_xml_path(XML_PATH)
    data = mujoco.MjData(model)

    def resolve(obj_type, name):
        # mj_name2id reports an unknown name as -1. Used unchecked, that id
        # would address the last body through negative indexing, or simply
        # never match any contact geom.
        obj_id = mujoco.mj_name2id(model, obj_type, name)
        if obj_id < 0:
            raise ValueError(f"{name!r} not found in MuJoCo model {XML_PATH}")
        return obj_id

    body_type = mujoco.mjtObj.mjOBJ_BODY
    geom_type = mujoco.mjtObj.mjOBJ_GEOM
    drone_body_id = resolve(body_type, "drone")
    target_body_id = resolve(body_type, "target")
    drone_body_geom_id = resolve(geom_type, "drone_body")
    ground_geom_id = resolve(geom_type, "ground")

    # Every lookup succeeded: commit the model and the cached ids together.
    self._model = model
    self._data = data
    self._drone_body_id = drone_body_id
    self._target_body_id = target_body_id
    self._drone_body_geom_id = drone_body_geom_id
    self._ground_geom_id = ground_geom_id

    # Publish the static interface description on the state record.
    self._state.action_spec = {
        "shape": [4],
        "dtype": "float64",
        "minimum": [-1.0, -1.0, -1.0, -1.0],
        "maximum": [1.0, 1.0, 1.0, 1.0],
        "name": "velocity_command",
    }
    self._state.observation_spec = {
        "position": {"shape": [3], "dtype": "float64"},
        "velocity": {"shape": [3], "dtype": "float64"},
        "orientation": {"shape": [3], "dtype": "float64"},
        "angular_velocity": {"shape": [3], "dtype": "float64"},
        "target_relative": {"shape": [3], "dtype": "float64"},
    }
    self._state.physics_timestep = PHYSICS_DT
    self._state.control_timestep = CONTROL_DT
# ------------------------------------------------------------------
# Target placement
# ------------------------------------------------------------------
def _place_target(self):
    """Draw a random target position and move the target body there.

    The target lands at a horizontal distance between ``TARGET_MIN_DIST``
    and ``TARGET_MAX_DIST`` from the world origin, in a uniformly random
    direction, at a height between 1.0 and 2.5. The position is stored in
    ``self._target_pos`` and written into the model's ``body_pos``.
    """
    import mujoco

    # Draw order (bearing, radius, height) is fixed so a seed keeps producing
    # the same target.
    bearing = self._rng.uniform(0, 2 * np.pi)
    radius = self._rng.uniform(TARGET_MIN_DIST, TARGET_MAX_DIST)
    east = radius * np.cos(bearing)
    north = radius * np.sin(bearing)
    height = self._rng.uniform(1.0, 2.5)

    self._target_pos = np.array([east, north, height])
    self._model.body_pos[self._target_body_id] = self._target_pos.copy()

    # body_pos is a model field, so derived quantities must be recomputed.
    mujoco.mj_forward(self._model, self._data)
# ------------------------------------------------------------------
# Flight controller
# ------------------------------------------------------------------
def _flight_controller(self, cmd: np.ndarray) -> np.ndarray:
    """Convert a normalized velocity command into four motor thrusts.

    Args:
        cmd: ``[vx, vy, vz, yaw_rate]``, each in ``[-1, 1]``. ``vx``/``vy``
            are speeds relative to the drone's current heading; every entry
            is scaled by the matching ``MAX_*`` limit.

    Returns:
        Array of motor thrusts in the order front-right, front-left,
        back-right, back-left, clamped to ``[0, 3]``.
    """
    # Scale normalized commands to physical units.
    vx_cmd = cmd[0] * MAX_XY_VEL
    vy_cmd = cmd[1] * MAX_XY_VEL
    vz_cmd = cmd[2] * MAX_Z_VEL
    yaw_rate_cmd = cmd[3] * MAX_YAW_RATE

    # Current state. Assumes the drone's free joint owns the first 7 qpos
    # and 6 qvel entries - TODO confirm against drone_forest.xml.
    quat = self._data.qpos[3:7].copy()  # w, x, y, z
    vel = self._data.qvel[:3].copy()
    ang_vel = self._data.qvel[3:6].copy()
    roll, pitch, yaw = self._quat_to_euler(quat)

    # Express the measured velocity in the heading (yaw-aligned) frame so the
    # horizontal errors line up with the axes pitch and roll act on.
    # Comparing in the world frame is only right while yaw == 0: after a
    # turn, a forward error would be fed to the wrong tilt axis.
    cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw)
    v_forward = cos_yaw * vel[0] + sin_yaw * vel[1]
    v_left = -sin_yaw * vel[0] + cos_yaw * vel[1]

    forward_err = vx_cmd - v_forward
    left_err = vy_cmd - v_left
    vz_err = vz_cmd - vel[2]

    # Desired tilt from the horizontal velocity error (small-angle approx.),
    # limited to +/-0.5 rad.
    desired_pitch = np.clip(KP_VEL * forward_err, -0.5, 0.5)
    desired_roll = np.clip(-KP_VEL * left_err, -0.5, 0.5)

    # Attitude PD: proportional on the angle error, damping on body rates.
    roll_err = desired_roll - roll
    pitch_err = desired_pitch - pitch
    yaw_rate_err = yaw_rate_cmd - ang_vel[2]
    torque_roll = KP_ATT * roll_err - KD_ATT * ang_vel[0]
    torque_pitch = KP_ATT * pitch_err - KD_ATT * ang_vel[1]
    torque_yaw = KP_ATT * yaw_rate_err

    # Collective thrust: weight compensation plus vertical velocity correction.
    thrust = DRONE_MASS * GRAVITY + KP_VEL * vz_err * DRONE_MASS

    # Quadrotor mixer: spread thrust and torques over the four motors.
    # Layout: FR(+x,-y), FL(+x,+y), BR(-x,-y), BL(-x,+y)
    # NOTE(review): the signs below assume the actuator order and rotor
    # placement in drone_forest.xml match this layout - verify.
    arm = ARM_LENGTH
    base = thrust / 4.0
    pitch_share = torque_pitch / (4.0 * arm)
    roll_share = torque_roll / (4.0 * arm)
    yaw_share = torque_yaw / 4.0
    t_fr = base + pitch_share - roll_share - yaw_share
    t_fl = base + pitch_share + roll_share + yaw_share
    t_br = base - pitch_share - roll_share + yaw_share
    t_bl = base - pitch_share + roll_share - yaw_share

    # Clamp to actuator range [0, 3]
    return np.clip([t_fr, t_fl, t_br, t_bl], 0.0, 3.0)
@staticmethod
def _quat_to_euler(quat: np.ndarray):
    """Return ``(roll, pitch, yaw)`` in radians for a ``[w, x, y, z]`` quaternion.

    Roll is the rotation about x, pitch about y and yaw about z. The sine of
    the pitch is clipped to ``[-1, 1]`` before ``arcsin`` so round-off cannot
    push it outside the function's domain.
    """
    qw, qx, qy, qz = quat

    roll = np.arctan2(
        2.0 * (qw * qx + qy * qz),
        1.0 - 2.0 * (qx * qx + qy * qy),
    )
    pitch = np.arcsin(np.clip(2.0 * (qw * qy - qz * qx), -1.0, 1.0))
    yaw = np.arctan2(
        2.0 * (qw * qz + qx * qy),
        1.0 - 2.0 * (qy * qy + qz * qz),
    )
    return roll, pitch, yaw
# ------------------------------------------------------------------
# Observations
# ------------------------------------------------------------------
def _get_obs(self) -> Dict[str, List[float]]:
    """Build the observation dictionary from the current simulation state.

    Returns:
        Plain-list entries for ``position``, ``velocity``, ``orientation``
        (roll, pitch, yaw), ``angular_velocity`` and ``target_relative``
        (target position minus drone position).
    """
    qpos, qvel = self._data.qpos, self._data.qvel

    position = qpos[:3].copy()
    linear_velocity = qvel[:3].copy()
    attitude = qpos[3:7].copy()
    body_rates = qvel[3:6].copy()

    euler = self._quat_to_euler(attitude)
    to_target = self._target_pos - position

    return {
        "position": position.tolist(),
        "velocity": linear_velocity.tolist(),
        "orientation": [float(angle) for angle in euler],
        "angular_velocity": body_rates.tolist(),
        "target_relative": to_target.tolist(),
    }
# ------------------------------------------------------------------
# Collision detection
# ------------------------------------------------------------------
def _check_collisions(self) -> bool:
    """Return True if the drone body geom is in contact with the ground geom.

    Only the pair (``drone_body``, ``ground``) counts; contacts involving any
    other geom are ignored.
    """
    crash_pair = {self._drone_body_geom_id, self._ground_geom_id}
    for idx in range(self._data.ncon):
        contact = self._data.contact[idx]
        # Comparing the unordered pair directly covers both geom orderings
        # and cannot fail on a degenerate contact whose two geoms are the
        # same (popping from an emptied set would raise KeyError).
        if {contact.geom1, contact.geom2} == crash_pair:
            return True
    return False
# ------------------------------------------------------------------
# Reward
# ------------------------------------------------------------------
def _compute_reward(self, pos: np.ndarray) -> float:
    """Return the shaping reward for this step and remember the new distance.

    The reward is 0.1 when the drone is strictly closer to the target than
    at the previous call, and 0.0 otherwise (including when no previous
    distance is stored). ``self._prev_dist`` is always updated.
    """
    current = np.linalg.norm(self._target_pos - pos)
    previous = self._prev_dist
    self._prev_dist = current

    if previous is None:
        return 0.0
    return 0.1 if current < previous else 0.0
# ------------------------------------------------------------------
# Termination
# ------------------------------------------------------------------
def _check_termination(self, pos: np.ndarray):
    """Decide whether the episode is over and which bonus applies.

    Checks run in priority order, and the first that matches wins:
    reaching the target (+100), ground collision (-50), leaving the arena
    or altitude band (-10), then hitting the step limit (0).

    Returns:
        ``(done, bonus_reward)``.
    """
    # Success takes priority over every failure condition.
    if np.linalg.norm(self._target_pos - pos) < TARGET_RADIUS:
        return True, 100.0

    if self._check_collisions():
        return True, -50.0

    x, y, z = pos[0], pos[1], pos[2]
    outside_arena = abs(x) > ARENA_HALF or abs(y) > ARENA_HALF
    outside_altitude = z > MAX_ALTITUDE or z < MIN_ALTITUDE
    if outside_arena or outside_altitude:
        return True, -10.0

    # Timing out ends the episode without any bonus or penalty.
    timed_out = self._step_count >= MAX_STEPS
    return timed_out, 0.0
# ------------------------------------------------------------------
# Core interface
# ------------------------------------------------------------------
def reset(
    self,
    domain_name: Optional[str] = None,
    task_name: Optional[str] = None,
    seed: Optional[int] = None,
    render: bool = False,
    **kwargs,
) -> DMControlObservation:
    """Start a new episode and return its first observation.

    Args:
        domain_name: Not used by this environment.
        task_name: Not used by this environment.
        seed: When given, re-seeds the RNG used for target placement.
        render: Attach a rendered frame to this observation; the choice is
            remembered and also applies to later ``step`` calls.
        **kwargs: Accepted and ignored.

    Returns:
        Observation with zero reward and ``done=False``.
    """
    import mujoco

    self._ensure_model()
    self._include_pixels = render
    if seed is not None:
        self._rng = np.random.RandomState(seed)

    model, data = self._model, self._data

    # Back to the model defaults, then pick a fresh target.
    mujoco.mj_resetData(model, data)
    self._place_target()

    # Spawn pose: hovering height 1.5 above the origin, identity
    # orientation, at rest.
    data.qpos[:3] = [0.0, 0.0, 1.5]
    data.qpos[3:7] = [1.0, 0.0, 0.0, 0.0]
    data.qvel[:] = 0.0
    mujoco.mj_forward(model, data)

    # Episode bookkeeping.
    self._done = False
    self._step_count = 0
    spawn = data.qpos[:3].copy()
    self._prev_dist = float(np.linalg.norm(self._target_pos - spawn))

    # New state record for the new episode; the specs carry over unchanged.
    previous_state = self._state
    self._state = DMControlState(
        episode_id=str(uuid4()),
        step_count=0,
        domain_name="drone_forest",
        task_name="navigate",
        action_spec=previous_state.action_spec,
        observation_spec=previous_state.observation_spec,
        physics_timestep=PHYSICS_DT,
        control_timestep=CONTROL_DT,
    )

    observations = self._get_obs()
    frame = self._render_pixels() if render else None
    return DMControlObservation(
        observations=observations,
        pixels=frame,
        reward=0.0,
        done=False,
    )
def step(
    self,
    action: DMControlAction,
    render: bool = False,
    **kwargs,
) -> DMControlObservation:
    """Advance the simulation by one control step.

    Args:
        action: Carries ``values``, whose first four entries are the
            normalized ``[vx, vy, vz, yaw_rate]`` command. Extra entries
            are ignored; entries are clipped to ``[-1, 1]``.
        render: Attach a rendered frame to the returned observation (a
            frame is also attached if ``reset`` was called with
            ``render=True``).
        **kwargs: Accepted and ignored.

    Returns:
        Observation with this step's reward (shaping term plus any
        termination bonus) and the ``done`` flag.

    Raises:
        RuntimeError: If no model is loaded yet, or the episode has ended.
        ValueError: If the action supplies fewer than four values.
    """
    import mujoco

    if self._model is None or self._data is None:
        raise RuntimeError("Environment not initialized. Call reset() first.")
    if self._done:
        raise RuntimeError("Episode is done. Call reset() to start a new episode.")

    # Fail early with a clear message instead of an IndexError deep inside
    # the flight controller.
    cmd = np.asarray(action.values[:4], dtype=np.float64)
    if cmd.shape != (4,):
        raise ValueError(
            "Action must provide at least 4 values [vx, vy, vz, yaw_rate], "
            f"got {cmd.size}"
        )
    # np.clip lets NaN through, which would poison the actuator controls;
    # treat NaN as "no command" and saturate infinities before clipping.
    cmd = np.clip(
        np.nan_to_num(cmd, nan=0.0, posinf=1.0, neginf=-1.0), -1.0, 1.0
    )

    # The flight controller maps the command to motor thrusts.
    self._data.ctrl[:4] = self._flight_controller(cmd)

    # One control step spans several physics substeps. round() avoids
    # dropping a substep when the float division lands just under an integer.
    n_substeps = max(1, int(round(CONTROL_DT / PHYSICS_DT)))
    for _ in range(n_substeps):
        mujoco.mj_step(self._model, self._data)

    self._step_count += 1
    self._state.step_count = self._step_count

    # Reward and termination both use the post-step position.
    pos = self._data.qpos[:3].copy()
    reward = self._compute_reward(pos)
    done, bonus = self._check_termination(pos)
    reward += bonus
    self._done = done

    obs = self._get_obs()
    pixels = self._render_pixels() if (render or self._include_pixels) else None
    return DMControlObservation(
        observations=obs,
        pixels=pixels,
        reward=float(reward),
        done=done,
    )
async def reset_async(self, **kwargs) -> DMControlObservation:
    """Async wrapper around ``reset``.

    On macOS the reset runs inline on the calling thread; on every other
    platform it is handed to a worker thread via ``asyncio.to_thread``.
    All keyword arguments are forwarded unchanged.
    """
    if sys.platform != "darwin":
        import asyncio

        return await asyncio.to_thread(self.reset, **kwargs)
    return self.reset(**kwargs)
async def step_async(self, action: DMControlAction, render: bool = False, **kwargs) -> DMControlObservation:
    """Async wrapper around ``step``.

    On macOS the step runs inline on the calling thread; on every other
    platform it is handed to a worker thread via ``asyncio.to_thread``.
    ``action``, ``render`` and any extra keyword arguments are forwarded.
    """
    if sys.platform != "darwin":
        import asyncio

        return await asyncio.to_thread(self.step, action, render=render, **kwargs)
    return self.step(action, render=render, **kwargs)
# ------------------------------------------------------------------
# Rendering
# ------------------------------------------------------------------
def _render_pixels(self) -> Optional[str]:
    """Render the ``tracking`` camera and return the frame as base64 PNG.

    Rendering is best-effort: any failure is printed with its traceback and
    ``None`` is returned, so an observation can still be produced without
    pixels.

    Returns:
        The PNG image encoded as a base64 UTF-8 string, or ``None`` if
        rendering failed.
    """
    try:
        import mujoco
        from PIL import Image

        renderer = mujoco.Renderer(
            self._model, height=self._render_height, width=self._render_width
        )
        try:
            renderer.update_scene(self._data, camera="tracking")
            frame = renderer.render()
        finally:
            # Release the renderer even when drawing fails; otherwise every
            # failed frame would leak its rendering context.
            renderer.close()

        buf = io.BytesIO()
        Image.fromarray(frame).save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")
    except Exception as e:
        import traceback

        print(f"[render error] {e}")
        traceback.print_exc()
        return None
@property
def state(self) -> DMControlState:
    """Current state record (episode id, step count, specs, timesteps)."""
    current = self._state
    return current
def close(self) -> None:
    """Drop the references to the MuJoCo model and data.

    A later ``reset`` loads the model again; ``step`` refuses to run until
    then.
    """
    self._model = self._data = None
def __del__(self):
    """Best-effort cleanup when the instance is garbage-collected."""
    try:
        self.close()
    except Exception:
        # Deliberately swallowed: a finalizer may run on a partially
        # constructed object or during interpreter shutdown.
        return