| | |
| | |
| | |
| |
|
| | import os |
| | from typing import Any, Optional |
| |
|
| |
|
| |
|
| | import os |
# Expose only the GPU with index 1 to this process. The assignment sits ahead
# of the torch import further down — presumably so it is in place before CUDA
# is initialised (TODO confirm against the deployment machine).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
| |
|
| |
|
| |
|
| | import cv2 |
| | import numpy as np |
| | import torch |
| |
|
| | from robomme.robomme_env import * |
| | from robomme.robomme_env.utils import * |
| | from robomme.env_record_wrapper import ( |
| | BenchmarkEnvBuilder, |
| | EpisodeDatasetResolver, |
| | ) |
| | from robomme.env_record_wrapper.OraclePlannerDemonstrationWrapper import ( |
| | OraclePlannerDemonstrationWrapper, |
| | ) |
| | from robomme.robomme_env.utils.choice_action_mapping import ( |
| | _unique_candidates, |
| | extract_actor_position_xyz, |
| | project_world_to_pixel, |
| | select_target_with_pixel, |
| | ) |
| | from robomme.robomme_env.utils.save_reset_video import save_robomme_video |
| |
|
| | |
# Action representation: passed to the env builder as ``action_space`` and
# also used as the key under which recorded steps are read from the dataset.
ACTION_SPACE = "joint_angle"

# Forwarded to the env builder as ``gui_render``; when true, main() also calls
# env.render() after every step.
GUI_RENDER = False

# Directory handed to EpisodeDatasetResolver as ``dataset_directory``.
DATASET_ROOT = "/data/hongzefu/data_0226"

# Environment ids that main() iterates over.
DEFAULT_ENV_IDS = ["BinFill"]

# Destination directory passed to save_robomme_video.
OUT_VIDEO_DIR = "/data/hongzefu/dataset_replay"

# Per-episode step limit passed to make_env_for_episode.
MAX_STEPS = 1000
| |
|
| |
|
| | def _parse_oracle_command(choice_action: Optional[Any]) -> Optional[dict[str, Any]]: |
| | if not isinstance(choice_action, dict): |
| | return None |
| | choice = choice_action.get("choice") |
| | if not isinstance(choice, str) or not choice.strip(): |
| | return None |
| | point = choice_action.get("point") |
| | if not isinstance(point, (list, tuple, np.ndarray)) or len(point) != 2: |
| | return None |
| | return choice_action |
| |
|
| |
|
| | def _to_numpy_copy(value: Any) -> np.ndarray: |
| | if isinstance(value, torch.Tensor): |
| | value = value.detach().cpu().numpy() |
| | else: |
| | value = np.asarray(value) |
| | return np.array(value, copy=True) |
| |
|
| |
|
| | def _to_frame_list(frames_like: Any) -> list[np.ndarray]: |
| | if frames_like is None: |
| | return [] |
| | if isinstance(frames_like, torch.Tensor): |
| | arr = frames_like.detach().cpu().numpy() |
| | if arr.ndim == 3: |
| | return [np.array(arr, copy=True)] |
| | if arr.ndim == 4: |
| | return [np.array(x, copy=True) for x in arr] |
| | return [] |
| | if isinstance(frames_like, np.ndarray): |
| | if frames_like.ndim == 3: |
| | return [np.array(frames_like, copy=True)] |
| | if frames_like.ndim == 4: |
| | return [np.array(x, copy=True) for x in frames_like] |
| | return [] |
| | if isinstance(frames_like, (list, tuple)): |
| | out = [] |
| | for frame in frames_like: |
| | if frame is None: |
| | continue |
| | out.append(_to_numpy_copy(frame)) |
| | return out |
| | try: |
| | arr = np.asarray(frames_like) |
| | except Exception: |
| | return [] |
| | if arr.ndim == 3: |
| | return [np.array(arr, copy=True)] |
| | if arr.ndim == 4: |
| | return [np.array(x, copy=True) for x in arr] |
| | return [] |
| |
|
| |
|
| | def _normalize_pixel_xy(pixel_like: Any) -> Optional[list[int]]: |
| | if not isinstance(pixel_like, (list, tuple, np.ndarray)): |
| | return None |
| | if len(pixel_like) < 2: |
| | return None |
| | try: |
| | x = float(pixel_like[0]) |
| | y = float(pixel_like[1]) |
| | except (TypeError, ValueError): |
| | return None |
| | if not np.isfinite(x) or not np.isfinite(y): |
| | return None |
| | return [int(np.rint(x)), int(np.rint(y))] |
| |
|
| |
|
| | def _normalize_point_yx_to_pixel_xy(point_like: Any) -> Optional[list[int]]: |
| | if not isinstance(point_like, (list, tuple, np.ndarray)): |
| | return None |
| | if len(point_like) < 2: |
| | return None |
| | try: |
| | y = float(point_like[0]) |
| | x = float(point_like[1]) |
| | except (TypeError, ValueError): |
| | return None |
| | if not np.isfinite(x) or not np.isfinite(y): |
| | return None |
| | return [int(np.rint(x)), int(np.rint(y))] |
| |
|
| |
|
def _find_oracle_wrapper(env_like: Any) -> Optional[OraclePlannerDemonstrationWrapper]:
    """Walk the ``.env`` chain looking for an OraclePlannerDemonstrationWrapper.

    Starting at ``env_like``, each object's ``env`` attribute is followed for
    at most 16 hops. The search gives up (returns ``None``) when the chain
    ends, when an object is seen twice, or when the hop limit is reached.
    """
    max_hops = 16
    seen_ids: set[int] = set()
    node = env_like
    hops = 0
    while hops < max_hops and node is not None:
        if isinstance(node, OraclePlannerDemonstrationWrapper):
            return node
        node_id = id(node)
        if node_id in seen_ids:
            # The wrapper chain loops back on itself; stop searching.
            return None
        seen_ids.add(node_id)
        node = getattr(node, "env", None)
        hops += 1
    return None
| |
|
| |
|
def _collect_multi_choice_visualization(
    env_like: Any,
    command: dict[str, Any],
) -> tuple[list[list[int]], Optional[list[int]], Optional[list[int]]]:
    """Gather the pixels needed to visualise one multi-choice command.

    Returns a ``(candidate_pixels, clicked_pixel, matched_pixel)`` triple:

    * ``candidate_pixels`` - projected ``[x, y]`` pixels of the actors listed
      under the resolved option's ``"available"`` entry.
    * ``clicked_pixel`` - ``command["point"]`` converted from (y, x) order to
      ``[x, y]``, or ``None`` if it cannot be normalised.
    * ``matched_pixel`` - the ``"projected_pixel"`` reported by
      ``select_target_with_pixel`` for the click, or ``None``.

    When no oracle wrapper is found, when building/resolving the options
    raises, or when the resolved index is out of range, the result is
    ``([], clicked_pixel, None)``.
    """
    clicked_pixel = _normalize_point_yx_to_pixel_xy(command.get("point"))

    wrapper = _find_oracle_wrapper(env_like)
    if wrapper is None:
        return [], clicked_pixel, None

    # Best effort: any failure inside the wrapper only disables the overlay.
    try:
        _, options = wrapper._build_step_options()
        option_index, _ = wrapper._resolve_command(command, options)
    except Exception:
        return [], clicked_pixel, None

    if option_index is None or not (0 <= option_index < len(options)):
        return [], clicked_pixel, None

    available = options[option_index].get("available")
    camera_kwargs = {
        "intrinsic_cv": getattr(wrapper, "_front_camera_intrinsic_cv", None),
        "extrinsic_cv": getattr(wrapper, "_front_camera_extrinsic_cv", None),
        "image_shape": getattr(wrapper, "_front_rgb_shape", None),
    }
    if available is None:
        # Nothing to project and nothing to match against.
        return [], clicked_pixel, None

    candidate_pixels: list[list[int]] = []
    for actor in _unique_candidates(available):
        position = extract_actor_position_xyz(actor)
        if position is None:
            continue
        pixel = project_world_to_pixel(position, **camera_kwargs)
        if pixel is None:
            continue
        candidate_pixels.append([int(pixel[0]), int(pixel[1])])

    matched_pixel: Optional[list[int]] = None
    if clicked_pixel is not None:
        match = select_target_with_pixel(
            available=available,
            pixel_like=clicked_pixel,
            **camera_kwargs,
        )
        if isinstance(match, dict):
            matched_pixel = _normalize_pixel_xy(match.get("projected_pixel"))

    return candidate_pixels, clicked_pixel, matched_pixel
| |
|
| |
|
| | def _make_blackboard(frame_like: Any) -> np.ndarray: |
| | frame = _to_numpy_copy(frame_like) |
| | if frame.ndim < 2: |
| | return np.zeros((1, 1, 3), dtype=np.uint8) |
| | h, w = int(frame.shape[0]), int(frame.shape[1]) |
| | if h <= 0 or w <= 0: |
| | return np.zeros((1, 1, 3), dtype=np.uint8) |
| | return np.zeros((h, w, 3), dtype=np.uint8) |
| |
|
| |
|
def _draw_candidate_blackboard(
    frame_like: Any,
    candidate_pixels: list[list[int]],
) -> np.ndarray:
    """Draw one thin ring per candidate pixel on a black canvas.

    The canvas comes from ``_make_blackboard(frame_like)``. Every entry of
    ``candidate_pixels`` with at least two values is drawn as a radius-4,
    1-pixel-thick circle in colour ``(0, 255, 255)``; shorter entries are
    skipped.
    """
    canvas = _make_blackboard(frame_like)
    ring_color = (0, 255, 255)
    for candidate in candidate_pixels:
        if len(candidate) >= 2:
            center = (int(candidate[0]), int(candidate[1]))
            cv2.circle(canvas, center, 4, ring_color, 1)
    return canvas
| |
|
| |
|
def _draw_selection_blackboard(
    frame_like: Any,
    clicked_pixel: Optional[list[int]],
    matched_pixel: Optional[list[int]],
) -> np.ndarray:
    """Draw the clicked point and the matched target on a black canvas.

    The canvas comes from ``_make_blackboard(frame_like)``. A non-``None``
    ``clicked_pixel`` is drawn as a tilted cross (size 10, thickness 1,
    colour ``(255, 255, 0)``); a non-``None`` ``matched_pixel`` is drawn as a
    radius-5, 2-pixel-thick circle in colour ``(255, 0, 0)``.
    """
    canvas = _make_blackboard(frame_like)

    if clicked_pixel is not None:
        click_xy = (int(clicked_pixel[0]), int(clicked_pixel[1]))
        cv2.drawMarker(
            canvas,
            click_xy,
            (255, 255, 0),
            markerType=cv2.MARKER_TILTED_CROSS,
            markerSize=10,
            thickness=1,
        )

    if matched_pixel is not None:
        match_xy = (int(matched_pixel[0]), int(matched_pixel[1]))
        cv2.circle(canvas, match_xy, 5, (255, 0, 0), 2)

    return canvas
| |
|
| |
|
| |
|
| |
|
def _replay_episode(env: Any, dataset_resolver: Any, env_id: str, episode: int) -> bool:
    """Replay one recorded episode in ``env`` and save the resulting video.

    Recorded actions are read from ``dataset_resolver`` under the
    ``ACTION_SPACE`` key, one per step, until no usable action is returned or
    the environment reports truncation/termination. The collected frames are
    then handed to ``save_robomme_video``.

    Returns:
        True when the episode terminated with ``info["status"] == "success"``.
    """
    is_multi_choice = ACTION_SPACE == "multi_choice"

    obs, info = env.reset()
    # obs/info also carry the depth, state and camera entries requested through
    # the include_* flags; only what the video needs is read here.
    # _to_frame_list already returns independent copies, so the frames are
    # stored as-is instead of being copied a second time.
    reset_base_frames = _to_frame_list(obs["front_rgb_list"])
    reset_wrist_frames = _to_frame_list(obs["wrist_rgb_list"])
    reset_right_frames = (
        [_make_blackboard(f) for f in reset_base_frames] if is_multi_choice else None
    )
    reset_far_right_frames = (
        [_make_blackboard(f) for f in reset_base_frames] if is_multi_choice else None
    )
    reset_subgoal_grounded = [info["grounded_subgoal_online"]] * len(reset_base_frames)

    rollout_base_frames: list[np.ndarray] = []
    rollout_wrist_frames: list[np.ndarray] = []
    rollout_right_frames: list[np.ndarray] = []
    rollout_far_right_frames: list[np.ndarray] = []
    rollout_subgoal_grounded: list[Any] = []

    step = 0
    episode_success = False
    while True:
        action = dataset_resolver.get_step(ACTION_SPACE, step)
        if is_multi_choice:
            action = _parse_oracle_command(action)
        if action is None:
            # No (valid) recorded action for this step: the replay is over.
            break

        candidate_pixels: list[list[int]] = []
        clicked_pixel: Optional[list[int]] = None
        matched_pixel: Optional[list[int]] = None
        if is_multi_choice:
            # Must be collected before env.step so it reflects the options
            # the command is resolved against.
            candidate_pixels, clicked_pixel, matched_pixel = (
                _collect_multi_choice_visualization(env, action)
            )

        obs, _reward, terminated, truncated, info = env.step(action)

        front_frames = _to_frame_list(obs["front_rgb_list"])
        wrist_frames = _to_frame_list(obs["wrist_rgb_list"])
        rollout_base_frames.extend(front_frames)
        rollout_wrist_frames.extend(wrist_frames)
        if is_multi_choice:
            for base_frame in front_frames:
                rollout_right_frames.append(
                    _draw_candidate_blackboard(
                        base_frame,
                        candidate_pixels=candidate_pixels,
                    )
                )
                rollout_far_right_frames.append(
                    _draw_selection_blackboard(
                        base_frame,
                        clicked_pixel=clicked_pixel,
                        matched_pixel=matched_pixel,
                    )
                )
        rollout_subgoal_grounded.extend(
            [info["grounded_subgoal_online"]] * len(front_frames)
        )

        # NOTE(review): assumes terminated/truncated expose ``.item()``
        # (tensor or NumPy scalar) - confirm against the env wrapper.
        terminated_flag = bool(terminated.item())
        truncated_flag = bool(truncated.item())

        step += 1
        if GUI_RENDER:
            env.render()
        if truncated_flag:
            print(f"[{env_id}] episode {episode} steps exceeded, step {step}.")
            break
        if terminated_flag:
            status = info.get("status")
            if status == "success":
                print(f"[{env_id}] episode {episode} success.")
                episode_success = True
            elif status == "fail":
                print(f"[{env_id}] episode {episode} failed.")
            break

    save_robomme_video(
        reset_base_frames=reset_base_frames,
        reset_wrist_frames=reset_wrist_frames,
        rollout_base_frames=rollout_base_frames,
        rollout_wrist_frames=rollout_wrist_frames,
        reset_subgoal_grounded=reset_subgoal_grounded,
        rollout_subgoal_grounded=rollout_subgoal_grounded,
        out_video_dir=OUT_VIDEO_DIR,
        action_space=ACTION_SPACE,
        env_id=env_id,
        episode=episode,
        episode_success=episode_success,
        reset_right_frames=reset_right_frames,
        rollout_right_frames=rollout_right_frames if is_multi_choice else None,
        reset_far_right_frames=reset_far_right_frames,
        rollout_far_right_frames=rollout_far_right_frames if is_multi_choice else None,
    )
    return episode_success


def main(only_episode: Optional[int] = 15) -> None:
    """Replay recorded episodes for every env in ``DEFAULT_ENV_IDS``.

    Args:
        only_episode: Index of the single episode to replay per environment.
            Defaults to 15, the episode this script was pinned to. Pass
            ``None`` to replay every episode listed in the metadata.
    """
    from robomme.logging_utils import setup_logging

    setup_logging(level="DEBUG")
    # Report the envs that are actually iterated below.
    print(f"Running envs: {DEFAULT_ENV_IDS}")
    print(f"Using action_space: {ACTION_SPACE}")

    for env_id in DEFAULT_ENV_IDS:
        env_builder = BenchmarkEnvBuilder(
            env_id=env_id,
            dataset="train",
            action_space=ACTION_SPACE,
            gui_render=GUI_RENDER,
        )
        episode_count = env_builder.get_episode_num()
        print(f"[{env_id}] episode_count from metadata: {episode_count}")

        for episode in range(episode_count):
            if only_episode is not None and episode != only_episode:
                continue

            env = env_builder.make_env_for_episode(
                episode,
                max_steps=MAX_STEPS,
                include_maniskill_obs=True,
                include_front_depth=True,
                include_wrist_depth=True,
                include_front_camera_extrinsic=True,
                include_wrist_camera_extrinsic=True,
                include_available_multi_choices=True,
                include_front_camera_intrinsic=True,
                include_wrist_camera_intrinsic=True,
            )
            # The finally clause closes the env exactly once on every path:
            # missing dataset entry, normal completion, or an exception
            # raised during the replay.
            try:
                try:
                    dataset_resolver = EpisodeDatasetResolver(
                        env_id=env_id,
                        episode=episode,
                        dataset_directory=DATASET_ROOT,
                    )
                except KeyError as e:
                    print(f"[{env_id}] Episode {episode} missing in H5, skipping. ({e})")
                    continue

                _replay_episode(env, dataset_resolver, env_id, episode)
            finally:
                if env is not None:
                    env.close()


if __name__ == "__main__":
    main()
| |
|