| |
| import gymnasium as gym |
| from gymnasium import spaces |
| import numpy as np |
| from collections import defaultdict |
| from collections.abc import Callable, Sequence, Mapping |
| from functools import partial |
| from typing import Any |
|
|
| |
| from robocasa.wrappers.gym_wrapper import RoboCasaGymEnv |
| from robocasa.utils.dataset_registry import ATOMIC_TASK_DATASETS, COMPOSITE_TASK_DATASETS, TARGET_TASKS, PRETRAINING_TASKS |
|
|
# Length of the flat "agent_pos" vector exposed in the observation space
# (see RoboCasaEnv._create_obs_and_action_space); convert_state() must
# produce a vector of this length — TODO confirm against the simulator state.
OBS_STATE_DIM = 16
# Length of the flat action vector; convert_action() slices indices 0..12.
ACTION_DIM = 12
# Bounds of the continuous action Box space.
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
|
|
def convert_state(dict_state):
    """Flatten a simulator state dict into the 1-D vector LeRobot expects.

    The listed components are concatenated in a fixed order; every key must
    be present in ``dict_state`` and hold a 1-D array.
    """
    dict_state = dict_state.copy()
    component_keys = (
        "state.base_position",
        "state.base_rotation",
        "state.end_effector_position_relative",
        "state.end_effector_rotation_relative",
        "state.gripper_qpos",
    )
    return np.concatenate([dict_state[key] for key in component_keys], axis=0)
|
|
def convert_action(action):
    """Split a flat LeRobot action vector into the dict the simulator consumes.

    The input is copied first, so the returned slices never alias the
    caller's array.
    """
    action = action.copy()
    layout = {
        "action.base_motion": (0, 4),
        "action.control_mode": (4, 5),
        "action.end_effector_position": (5, 8),
        "action.end_effector_rotation": (8, 11),
        "action.gripper_close": (11, 12),
    }
    return {key: action[start:stop] for key, (start, stop) in layout.items()}
|
|
| def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]: |
| """์นด๋ฉ๋ผ ์ด๋ฆ์ ๋ฆฌ์คํธ ํํ๋ก ์ ๊ทํ(Normalization)ํฉ๋๋ค.""" |
| if isinstance(camera_name, str): |
| cams = [c.strip() for c in camera_name.split(",") if c.strip()] |
| elif isinstance(camera_name, (list, tuple)): |
| cams = [str(c).strip() for c in camera_name if str(c).strip()] |
| else: |
| raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}") |
| if not cams: |
| raise ValueError("camera_name resolved to an empty list.") |
| return cams |
|
|
class RoboCasaEnv(RoboCasaGymEnv):
    """Gymnasium-style wrapper exposing RoboCasa tasks in the LeRobot format.

    Translates between the simulator's dict observations/actions and the flat
    ``agent_pos`` vector plus per-camera ``pixels`` dict that LeRobot expects.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 20}

    def __init__(
        self,
        task: str,
        # BUGFIX: default was a mutable list; a tuple avoids the shared
        # mutable-default pitfall while remaining a valid Sequence[str].
        camera_name: Sequence[str] = ("robot0_agentview_left", "robot0_eye_in_hand", "robot0_agentview_right"),
        render_mode: str = "rgb_array",
        obs_type: str = "pixels_agent_pos",
        observation_width: int = 256,
        observation_height: int = 256,
        split: str | None = None,
        **kwargs,
    ):
        """Create a RoboCasa environment for ``task``.

        Args:
            task: Task name; must be a key of the atomic/composite registries.
            camera_name: Camera names whose frames appear under ``pixels``.
            render_mode: Rendering mode; rendering is enabled iff not None.
            obs_type: ``"pixels"`` or ``"pixels_agent_pos"``.
            observation_width/observation_height: Camera frame size in pixels.
            split: Dataset split forwarded to the underlying env.
            **kwargs: Extra options forwarded to ``RoboCasaGymEnv``.

        Raises:
            ValueError: If ``task`` is not a known task name.
        """
        self.obs_type = obs_type
        self.render_mode = render_mode
        self.split = split
        self.task = task

        # ``fps`` is a LeRobot-config key the underlying env does not accept.
        kwargs.pop("fps", None)
        self.kwargs = kwargs

        meta_info = {**ATOMIC_TASK_DATASETS, **COMPOSITE_TASK_DATASETS}
        try:
            self._max_episode_steps = meta_info[task]['horizon']
        except KeyError:
            # ``from None`` keeps the internal KeyError out of the traceback.
            raise ValueError(f"Unknown task '{task}'. Valid tasks are: {list(meta_info.keys())}") from None

        super().__init__(
            task,
            camera_names=camera_name,
            camera_widths=observation_width,
            camera_heights=observation_height,
            enable_render=(render_mode is not None),
            split=split,
            **kwargs,
        )

    def _create_obs_and_action_space(self):
        """Build ``observation_space`` and ``action_space`` for ``obs_type``."""
        images = {
            cam: spaces.Box(
                low=0, high=255, shape=(self.camera_heights, self.camera_widths, 3), dtype=np.uint8
            )
            for cam in self.camera_names
        }
        if self.obs_type == "state":
            raise NotImplementedError("The 'state' observation type is not supported.")
        elif self.obs_type == "pixels":
            self.observation_space = spaces.Dict({"pixels": spaces.Dict(images)})
        elif self.obs_type == "pixels_agent_pos":
            self.observation_space = spaces.Dict({
                "pixels": spaces.Dict(images),
                "agent_pos": spaces.Box(low=-1000, high=1000, shape=(OBS_STATE_DIM,), dtype=np.float32),
            })
        else:
            raise ValueError(f"Unknown obs_type: {self.obs_type}")

        self.action_space = spaces.Box(
            low=ACTION_LOW, high=ACTION_HIGH, shape=(int(ACTION_DIM),), dtype=np.float32
        )

    def reset(self, seed: int | None = None, **kwargs):
        """Reset the episode and return the LeRobot-formatted observation."""
        # NOTE(review): frees the offscreen GL context before resetting —
        # presumably a renderer-leak workaround; confirm against the simulator.
        self.unwrapped.sim._render_context_offscreen.gl_ctx.free()
        # BUGFIX: gymnasium's Env.reset declares ``seed`` keyword-only, so it
        # must be forwarded by keyword rather than positionally.
        observation, info = super().reset(seed=seed, **kwargs)
        return self._format_raw_obs(observation), info

    def _format_raw_obs(self, raw_obs: dict):
        """Convert a raw simulator observation dict into the LeRobot layout."""
        new_obs = {}
        if self.obs_type == "pixels_agent_pos":
            new_obs["agent_pos"] = convert_state(raw_obs)
        # Keys containing "video." are camera frames; strip the prefix.
        new_obs["pixels"] = {}
        for k, v in raw_obs.items():
            if "video." in k:
                new_obs["pixels"][k.replace("video.", "")] = v
        return new_obs

    def step(self, action: np.ndarray):
        """Step with a flat action vector; returns the 5-tuple gym API."""
        # Re-bind the offscreen GL context before rendering this step.
        self.unwrapped.sim._render_context_offscreen.gl_ctx.make_current()
        action_dict = convert_action(action)
        observation, reward, done, truncated, info = super().step(action_dict)
        new_obs = self._format_raw_obs(observation)

        # An episode terminates either when the env reports done or on success.
        is_success = bool(info.get("success", 0))
        terminated = done or is_success
        info.update({"task": self.task, "done": done, "is_success": is_success})

        if terminated:
            info["final_info"] = {"task": self.task, "done": bool(done), "is_success": bool(is_success)}
            # Auto-reset on episode end; the returned ``new_obs`` is the final
            # pre-reset observation, matching classic vector-env autoreset.
            self.reset()

        return new_obs, reward, terminated, truncated, info
|
|
|
|
| def _make_env_fns(task_name: str, n_envs: int, camera_names: list[str], gym_kwargs: Mapping[str, Any]): |
| def _make_env(episode_index: int, **kwargs): |
| seed = kwargs.pop("seed", episode_index) |
| return RoboCasaEnv(task=task_name, camera_name=camera_names, seed=seed, **kwargs) |
|
|
| return [partial(_make_env, i, **gym_kwargs) for i in range(n_envs)] |
|
|
|
|
| |
| |
| |
def make_env(n_envs: int = 1, use_async_envs: bool = False, cfg=None) -> dict[str, dict[int, Any]]:
    """Main entry point LeRobot calls when loading environments from the Hub.

    Builds one vectorized environment (of ``n_envs`` copies) per resolved
    task and returns a mapping ``{task_name: {0: vector_env}}``.
    """
    if use_async_envs:
        env_cls = partial(gym.vector.AsyncVectorEnv, context="spawn")
    else:
        env_cls = gym.vector.SyncVectorEnv

    if cfg is None:
        # No config supplied: fall back to the default task and settings.
        task_name = "CloseFridge"
        gym_kwargs = {
            "obs_type": "pixels_agent_pos",
            "render_mode": "rgb_array",
            "observation_width": 256,
            "observation_height": 256,
            "camera_name": "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right",
            "split": None,
        }
    else:
        task_name = getattr(cfg, "task", "CloseFridge")
        gym_kwargs = {
            "obs_type": getattr(cfg, "obs_type", "pixels_agent_pos"),
            "render_mode": getattr(cfg, "render_mode", "rgb_array"),
            "observation_width": getattr(cfg, "observation_width", 256),
            "observation_height": getattr(cfg, "observation_height", 256),
            "camera_name": getattr(cfg, "camera_name", "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"),
            "split": getattr(cfg, "split", None),
            "fps": getattr(cfg, "fps", 20),
        }

    parsed_camera_names = _parse_camera_names(gym_kwargs.pop("camera_name"))

    # A suite name expands into its predefined task list plus a forced split;
    # otherwise the task string is treated as a comma-separated task list.
    combined_tasks = {**TARGET_TASKS, **PRETRAINING_TASKS}
    if task_name in combined_tasks:
        task_names = combined_tasks[task_name]
        gym_kwargs["split"] = "target" if task_name in TARGET_TASKS else "pretrain"
    else:
        task_names = [t.strip() for t in task_name.split(",")]

    envs: dict[str, dict[int, Any]] = {}
    for task in task_names:
        fns = _make_env_fns(
            task_name=task,
            n_envs=n_envs,
            camera_names=parsed_camera_names,
            gym_kwargs=gym_kwargs,
        )
        envs[task] = {0: env_cls(fns)}

    return envs