# RoboMME/src/robomme/env_record_wrapper/OraclePlannerDemonstrationWrapper.py
# HongzeFu
# HF Space: code-only (no binary assets)
# Revision: 06c11b0
import gymnasium as gym
import numpy as np
import torch
from robomme.robomme_env.utils.vqa_options import get_vqa_options
from mani_skill.examples.motionplanning.panda.motionplanner import (
PandaArmMotionPlanningSolver,
)
from mani_skill.examples.motionplanning.panda.motionplanner_stick import (
PandaStickMotionPlanningSolver,
)
from ..robomme_env.utils import planner_denseStep
from ..robomme_env.utils.oracle_action_matcher import (
find_exact_label_option_index,
)
from ..robomme_env.utils.choice_action_mapping import select_target_with_pixel
from ..logging_utils import logger
# -----------------------------------------------------------------------------
# Module: Oracle Planner Demonstration Wrapper
# Connects the RoboMME oracle planning logic to a Gym environment and supports
# step-by-step observation collection.
# The oracle logic below is inlined from history_bench_sim.oracle_logic. It works together with
# planner_denseStep, which aggregates multiple internal env.step calls into one unified batch return.
# -----------------------------------------------------------------------------
class OraclePlannerDemonstrationWrapper(gym.Wrapper):
    """
    Gym wrapper that drives a RoboMME environment through oracle planner options.

    The input to ``step`` is a ``command_dict`` containing a ``"choice"`` label
    and an optional pixel ``"point"``. ``step`` returns ``obs`` as the dense
    batch produced by ``planner_denseStep`` (dict-of-lists) and
    reward/terminated/truncated as last-step values.
    """

    def __init__(self, env, env_id, gui_render=True):
        """
        Store configuration and initialise empty planner/camera caches.

        Args:
            env: Environment to wrap.
            env_id: Environment identifier; used to pick the planner type and
                to build the option list.
            gui_render: Forwarded to the planner as its ``vis`` flag.
        """
        super().__init__(env)
        self.env_id = env_id
        self.gui_render = gui_render
        self.planner = None
        self.language_goal = None
        # State: options currently offered to the caller (rebuilt on reset/step).
        self.available_options = []
        # Retry budgets used by the screw -> RRT* fallback installed in reset().
        self._oracle_screw_max_attempts = 3
        self._oracle_rrt_max_attempts = 3
        # Last valid front-camera parameters / image size seen in obs/info.
        self._front_camera_intrinsic_cv = None
        self._front_camera_extrinsic_cv = None
        self._front_rgb_shape = None
        # Action/Observation space (empty Dict here, agreed externally).
        self.action_space = gym.spaces.Dict({})
        self.observation_space = gym.spaces.Dict({})

    def _wrap_planner_with_screw_then_rrt_retry(self, planner, screw_failure_exc):
        """
        Replace ``planner.move_to_pose_with_screw`` with a retrying variant.

        The replacement first retries the original screw planner up to
        ``_oracle_screw_max_attempts`` times, then retries
        ``planner.move_to_pose_with_RRTStar`` (same arguments) up to
        ``_oracle_rrt_max_attempts`` times. An attempt counts as failed when it
        raises or returns the integer ``-1``.

        Args:
            planner: Planner instance to patch in place.
            screw_failure_exc: Exception type treated as a screw planning failure.

        Returns:
            The same ``planner`` object, patched.

        Raises:
            RuntimeError: From the patched method, when every screw and RRT*
                attempt failed.
        """
        original_move_to_pose_with_screw = planner.move_to_pose_with_screw
        original_move_to_pose_with_rrt = planner.move_to_pose_with_RRTStar

        def _move_to_pose_with_screw_then_rrt_retry(*args, **kwargs):
            # Phase 1: screw planning; only `screw_failure_exc` is treated as retryable.
            for attempt in range(1, self._oracle_screw_max_attempts + 1):
                try:
                    result = original_move_to_pose_with_screw(*args, **kwargs)
                except screw_failure_exc as exc:
                    logger.debug(
                        f"[OraclePlannerWrapper] screw planning failed "
                        f"(attempt {attempt}/{self._oracle_screw_max_attempts}): {exc}"
                    )
                    continue
                if isinstance(result, int) and result == -1:
                    logger.debug(
                        f"[OraclePlannerWrapper] screw planning returned -1 "
                        f"(attempt {attempt}/{self._oracle_screw_max_attempts})"
                    )
                    continue
                return result

            logger.debug(
                "[OraclePlannerWrapper] screw planning exhausted; "
                f"fallback to RRT* (max {self._oracle_rrt_max_attempts} attempts)"
            )

            # Phase 2: RRT* fallback; any exception counts as a failed attempt.
            for attempt in range(1, self._oracle_rrt_max_attempts + 1):
                try:
                    result = original_move_to_pose_with_rrt(*args, **kwargs)
                except Exception as exc:
                    logger.debug(
                        f"[OraclePlannerWrapper] RRT* planning failed "
                        f"(attempt {attempt}/{self._oracle_rrt_max_attempts}): {exc}"
                    )
                    continue
                if isinstance(result, int) and result == -1:
                    logger.debug(
                        f"[OraclePlannerWrapper] RRT* planning returned -1 "
                        f"(attempt {attempt}/{self._oracle_rrt_max_attempts})"
                    )
                    continue
                return result

            raise RuntimeError(
                "[OraclePlannerWrapper] screw->RRT* planning exhausted; "
                f"screw_attempts={self._oracle_screw_max_attempts}, "
                f"rrt_attempts={self._oracle_rrt_max_attempts}"
            )

        planner.move_to_pose_with_screw = _move_to_pose_with_screw_then_rrt_retry
        return planner

    def reset(self, **kwargs):
        """
        Create a fresh planner, reset the wrapped env and publish the options.

        Returns:
            ``(obs, info)``. When ``info`` is a dict it gains the key
            ``"available_multi_choices"`` holding the current option list.
        """
        # Prefer fail-aware planners; fallback to base planners if import fails.
        try:
            from ..robomme_env.utils.planner_fail_safe import (
                FailAwarePandaArmMotionPlanningSolver,
                FailAwarePandaStickMotionPlanningSolver,
                ScrewPlanFailure,
            )
        except Exception as exc:
            logger.debug(
                "[OraclePlannerWrapper] Warning: failed to import planner_fail_safe, "
                f"fallback to base planners: {exc}"
            )
            FailAwarePandaArmMotionPlanningSolver = PandaArmMotionPlanningSolver
            FailAwarePandaStickMotionPlanningSolver = PandaStickMotionPlanningSolver

            class ScrewPlanFailure(RuntimeError):
                """Placeholder exception type when fail-aware planner import is unavailable."""

        # Select stick or arm planner based on env_id and initialize.
        if self.env_id in ("PatternLock", "RouteStick"):
            self.planner = FailAwarePandaStickMotionPlanningSolver(
                self.env,
                debug=False,
                vis=self.gui_render,
                base_pose=self.env.unwrapped.agent.robot.pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
                joint_vel_limits=0.3,
            )
        else:
            self.planner = FailAwarePandaArmMotionPlanningSolver(
                self.env,
                debug=False,
                vis=self.gui_render,
                base_pose=self.env.unwrapped.agent.robot.pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
            )
        # NOTE(review): applied to both planner kinds; the source indentation was
        # lost, so confirm this call was not meant to sit inside the `else` branch.
        self._wrap_planner_with_screw_then_rrt_retry(
            self.planner,
            screw_failure_exc=ScrewPlanFailure,
        )

        ret = self.env.reset(**kwargs)
        # Tolerate inner envs that return only `obs` instead of `(obs, info)`.
        if isinstance(ret, tuple) and len(ret) == 2:
            obs, info = ret
        else:
            obs, info = ret, {}

        self._update_front_camera_cache(obs_like=obs, info_like=info)
        self._build_step_options()
        if isinstance(info, dict):
            info["available_multi_choices"] = self.available_options
        return obs, info

    @staticmethod
    def _flatten_info_batch(info_batch: dict) -> dict:
        """Return a copy of ``info_batch`` where each non-empty list is replaced by its last item."""
        return {k: v[-1] if isinstance(v, list) and v else v for k, v in info_batch.items()}

    @staticmethod
    def _take_last_step_value(value):
        """
        Return the last element of a batched value.

        Tensors and arrays are flattened before taking the last element; empty
        or 0-d tensors/arrays, empty sequences and all other types are returned
        unchanged.
        """
        if isinstance(value, torch.Tensor):
            if value.numel() == 0 or value.ndim == 0:
                return value
            return value.reshape(-1)[-1]
        if isinstance(value, np.ndarray):
            if value.size == 0 or value.ndim == 0:
                return value
            return value.reshape(-1)[-1]
        if isinstance(value, (list, tuple)):
            return value[-1] if value else value
        return value

    @staticmethod
    def _to_numpy(value):
        """Convert ``value`` to a NumPy array (tensors are detached and moved to CPU); ``None`` stays ``None``."""
        if value is None:
            return None
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().numpy()
        return np.asarray(value)

    @staticmethod
    def _take_last_columnar(value):
        """Return the last item of a list (``None`` for an empty list); non-lists pass through."""
        if isinstance(value, list):
            return value[-1] if value else None
        return value

    @classmethod
    def _normalize_intrinsic_cv(cls, intrinsic_like):
        """
        Coerce ``intrinsic_like`` to a finite ``float64`` 3x3 matrix.

        The input is flattened and its first 9 values are used. Returns ``None``
        when the input is ``None``, has fewer than 9 values, or contains
        non-finite values.
        """
        intrinsic = cls._to_numpy(intrinsic_like)
        if intrinsic is None:
            return None
        intrinsic = intrinsic.reshape(-1)
        if intrinsic.size < 9:
            return None
        intrinsic = intrinsic[:9].reshape(3, 3)
        if not np.all(np.isfinite(intrinsic)):
            return None
        return intrinsic.astype(np.float64, copy=False)

    @classmethod
    def _normalize_extrinsic_cv(cls, extrinsic_like):
        """
        Coerce ``extrinsic_like`` to a finite ``float64`` 3x4 matrix.

        The input is flattened and its first 12 values are used. Returns
        ``None`` when the input is ``None``, has fewer than 12 values, or
        contains non-finite values.
        """
        extrinsic = cls._to_numpy(extrinsic_like)
        if extrinsic is None:
            return None
        extrinsic = extrinsic.reshape(-1)
        if extrinsic.size < 12:
            return None
        extrinsic = extrinsic[:12].reshape(3, 4)
        if not np.all(np.isfinite(extrinsic)):
            return None
        return extrinsic.astype(np.float64, copy=False)

    def _update_front_camera_cache(self, obs_like=None, info_like=None):
        """
        Refresh cached front-camera image shape, extrinsic and intrinsic.

        Reads ``front_rgb_list`` and ``front_camera_extrinsic_list`` from the
        observation and ``front_camera_intrinsic`` from the info. Each cached
        value is overwritten only when a valid new value is found, so earlier
        values survive steps that do not carry camera data.
        """
        obs_dict = obs_like if isinstance(obs_like, dict) else {}
        info_dict = info_like if isinstance(info_like, dict) else {}

        front_rgb = self._take_last_columnar(obs_dict.get("front_rgb_list"))
        front_rgb_np = self._to_numpy(front_rgb)
        if front_rgb_np is not None and front_rgb_np.ndim >= 2:
            # First two axes are cached as the image shape
            # (presumably (height, width) -- TODO confirm).
            self._front_rgb_shape = tuple(front_rgb_np.shape[:2])

        front_extrinsic = self._take_last_columnar(
            obs_dict.get("front_camera_extrinsic_list")
        )
        front_extrinsic_np = self._normalize_extrinsic_cv(front_extrinsic)
        if front_extrinsic_np is not None:
            self._front_camera_extrinsic_cv = front_extrinsic_np

        front_intrinsic = self._take_last_columnar(info_dict.get("front_camera_intrinsic"))
        front_intrinsic_np = self._normalize_intrinsic_cv(front_intrinsic)
        if front_intrinsic_np is not None:
            self._front_camera_intrinsic_cv = front_intrinsic_np

    @staticmethod
    def _empty_target():
        """Return a fresh target holder with every field set to ``None``."""
        return {
            "obj": None,
            "name": None,
            "seg_id": None,
            "position": None,
            "match_distance": None,
            "selection_mode": None,
        }

    def _build_step_options(self):
        """
        Query the current solver options and refresh ``self.available_options``.

        Returns:
            ``(selected_target, solve_options)``: the mutable target holder that
            was handed to ``get_vqa_options`` and the raw option list it returned.
        """
        selected_target = self._empty_target()
        solve_options = get_vqa_options(self.env, self.planner, selected_target, self.env_id)
        # NOTE(review): `need_parameter` is the truthiness of `available`, whereas
        # step() requires a point whenever the key `available` is present; the two
        # disagree for an option carrying an empty `available` -- confirm intent.
        self.available_options = [
            {
                "label": opt.get("label"),
                "action": opt.get("action", "Unknown"),
                "need_parameter": bool(opt.get("available")),
            }
            for opt in solve_options
        ]
        return selected_target, solve_options

    def _resolve_command(self, command_dict, solve_options):
        """
        Resolve a command into an option index and an optional target pixel.

        Args:
            command_dict: Dict with a string ``"choice"`` and optionally a
                ``"point"`` given as ``[y, x]``.
            solve_options: Options as returned by ``_build_step_options``.

        Returns:
            ``(None, None)`` when the command is malformed or the choice does
            not match any option label; ``(index, None)`` when no usable point
            is given; otherwise ``(index, [x, y])`` with integer-rounded
            coordinates.
        """
        if not isinstance(command_dict, dict):
            return None, None
        if "choice" not in command_dict:
            return None, None
        target_choice = command_dict.get("choice")
        if not isinstance(target_choice, str):
            return None, None
        target_choice = target_choice.strip()
        if not target_choice:
            return None, None

        target_label = target_choice.lower()
        found_idx = find_exact_label_option_index(target_label, solve_options)
        if found_idx == -1:
            logger.debug(
                f"Error: Choice '{target_choice}' not found in current options by exact label match."
            )
            return None, None

        point = command_dict.get("point")
        if point is None:
            return found_idx, None
        if not isinstance(point, (list, tuple, np.ndarray)):
            return found_idx, None
        try:
            # len() lives inside the try: a 0-d ndarray raises TypeError here and
            # must be treated like any other unusable point instead of crashing.
            if len(point) < 2:
                return found_idx, None
            y = float(point[0])
            x = float(point[1])
        except (TypeError, ValueError):
            return found_idx, None
        if not np.isfinite(x) or not np.isfinite(y):
            return found_idx, None
        # select_target_with_pixel expects [x, y].
        return found_idx, [int(np.rint(x)), int(np.rint(y))]

    def _apply_position_target(self, selected_target, option, target_pixel):
        """
        Fill ``selected_target`` with the candidate chosen for ``target_pixel``.

        Does nothing when ``target_pixel`` is ``None`` or when
        ``select_target_with_pixel`` returns no candidate.
        """
        if target_pixel is None:
            return
        best_cand = select_target_with_pixel(
            available=option.get("available"),
            pixel_like=target_pixel,
            intrinsic_cv=self._front_camera_intrinsic_cv,
            extrinsic_cv=self._front_camera_extrinsic_cv,
            image_shape=self._front_rgb_shape,
        )
        if best_cand is not None:
            # In-place update: the option's solve() is built around this same dict.
            selected_target.update(best_cand)

    def _execute_selected_option(self, option_idx, solve_options):
        """
        Run the selected option's ``solve`` callable with dense step collection.

        Returns:
            The value returned by ``planner_denseStep._run_with_dense_collection``.

        Raises:
            RuntimeError: If that value equals ``-1``.
        """
        option = solve_options[option_idx]
        logger.debug(f"Executing option: {option_idx + 1} - {option.get('action')}")
        result = planner_denseStep._run_with_dense_collection(
            self.planner,
            lambda: option.get("solve")(),
        )
        if result == -1:
            action_text = option.get("action", "Unknown")
            raise RuntimeError(
                f"Oracle solve failed after screw->RRT* retries for env '{self.env_id}', "
                f"action '{action_text}' (index {option_idx + 1})."
            )
        return result

    def _post_eval(self):
        """Run the environment's evaluation after a solve and log the result."""
        # NOTE(review): the first call's return value is discarded; kept as-is on
        # the assumption that evaluate() has side effects the env relies on.
        self.env.unwrapped.evaluate()
        evaluation = self.env.unwrapped.evaluate(solve_complete_eval=True)
        logger.debug(f"Evaluation result: {evaluation}")

    def _format_step_output(self, batch):
        """
        Convert a dense step batch into the wrapper's ``step`` return value.

        Returns:
            ``(obs_batch, reward, terminated, truncated, info)`` where the three
            signals are last-step values and ``info`` is the flattened info
            batch plus ``"available_multi_choices"``.
        """
        obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = batch
        self._update_front_camera_cache(obs_like=obs_batch, info_like=info_batch)
        info_flat = self._flatten_info_batch(info_batch)
        info_flat["available_multi_choices"] = getattr(self, "available_options", [])
        return (
            obs_batch,
            self._take_last_step_value(reward_batch),
            self._take_last_step_value(terminated_batch),
            self._take_last_step_value(truncated_batch),
            info_flat,
        )

    def step(self, action):
        """
        Execute one step: action is command_dict, must contain "choice", optional
        pixel `point=[y, x]` in front_rgb.

        Return last-step signals for reward/terminated/truncated while keeping obs as dict-of-lists.

        Raises:
            ValueError: If the chosen option has an ``"available"`` entry but no
                usable point was given, or the point matched no candidate.
            RuntimeError: If the selected solve reports failure (``-1``).
        """
        # 1) Build solver options once and prepare a mutable selected_target holder for solve() closures.
        selected_target, solve_options = self._build_step_options()

        # 2) Validate/resolve the incoming command into (option index, optional target position).
        found_idx, target_pixel = self._resolve_command(action, solve_options)

        # 3) For invalid command or unmatched choice, keep legacy behavior: return an empty dense batch.
        if found_idx is None:
            return self._format_step_output(planner_denseStep.empty_step_batch())

        # 4) If a point is provided, map it to the nearest candidate target.
        option = solve_options[found_idx]
        self._apply_position_target(
            selected_target=selected_target,
            option=option,
            target_pixel=target_pixel,
        )
        requires_target = "available" in option
        if requires_target:
            if target_pixel is None:
                raise ValueError(
                    f"Multi-choice action '{option.get('action', 'Unknown')}' requires "
                    "a target pixel point=[y, x], but command did not provide it."
                )
            if selected_target.get("obj") is None:
                raise ValueError(
                    f"Multi-choice action '{option.get('action', 'Unknown')}' could not match "
                    f"any available candidate from point={target_pixel}."
                )

        # 5) Execute selected solve() with dense step collection; raise on solve == -1.
        batch = self._execute_selected_option(found_idx, solve_options)

        # 6) Run post-solve environment evaluation to keep existing side effects and logging.
        self._post_eval()

        # 7) Convert batch to wrapper output contract (last reward/terminated/truncated + flattened info).
        # Routed through the module logger instead of a bare print to stdout.
        logger.debug("[OraclePlannerWrapper] step completed")
        return self._format_step_output(batch)