# RoboCasa_Env / env.py
# Uploaded by Whalswp via huggingface_hub (commit d421c67, verified)
# env.py
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from collections import defaultdict
from collections.abc import Callable, Sequence, Mapping
from functools import partial
from typing import Any
# RoboCasa ์ „์šฉ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์ž„ํฌํŠธ
from robocasa.wrappers.gym_wrapper import RoboCasaGymEnv
from robocasa.utils.dataset_registry import ATOMIC_TASK_DATASETS, COMPOSITE_TASK_DATASETS, TARGET_TASKS, PRETRAINING_TASKS
OBS_STATE_DIM = 16
ACTION_DIM = 12
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
def convert_state(dict_state):
    """Convert a simulator state dict into the flat vector LeRobot expects.

    Concatenates, in this fixed order: base position, base rotation, relative
    end-effector position, relative end-effector rotation, gripper joint
    positions. Downstream code indexes into this layout, so the order must
    not change.

    Args:
        dict_state: mapping containing the five ``state.*`` numpy-array
            entries read below (values are 1-D arrays; total length is
            expected to be OBS_STATE_DIM).

    Returns:
        1-D ``np.float32`` array — cast explicitly so it matches the dtype
        declared for the ``agent_pos`` observation space.
    """
    # No defensive copy needed: the dict is only read, and np.concatenate
    # already returns a fresh array.
    final_state = np.concatenate(
        [
            dict_state["state.base_position"],
            dict_state["state.base_rotation"],
            dict_state["state.end_effector_position_relative"],
            dict_state["state.end_effector_rotation_relative"],
            dict_state["state.gripper_qpos"],
        ],
        axis=0,
    )
    return final_state.astype(np.float32, copy=False)
def convert_action(action):
    """Split a flat LeRobot action vector into the simulator's action dict.

    Args:
        action: 1-D numpy array of length ACTION_DIM (12).

    Returns:
        dict mapping ``action.*`` keys to the corresponding slices.
    """
    # Copy first: the returned slices are views into this private buffer, so
    # later mutation of the caller's array cannot leak into the output.
    buf = action.copy()
    # (key, start, stop) layout of the 12-dim action vector.
    layout = (
        ("action.base_motion", 0, 4),
        ("action.control_mode", 4, 5),
        ("action.end_effector_position", 5, 8),
        ("action.end_effector_rotation", 8, 11),
        ("action.gripper_close", 11, 12),
    )
    return {key: buf[start:stop] for key, start, stop in layout}
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
"""์นด๋ฉ”๋ผ ์ด๋ฆ„์„ ๋ฆฌ์ŠคํŠธ ํ˜•ํƒœ๋กœ ์ •๊ทœํ™”(Normalization)ํ•ฉ๋‹ˆ๋‹ค."""
if isinstance(camera_name, str):
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
elif isinstance(camera_name, (list, tuple)):
cams = [str(c).strip() for c in camera_name if str(c).strip()]
else:
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
if not cams:
raise ValueError("camera_name resolved to an empty list.")
return cams
class RoboCasaEnv(RoboCasaGymEnv):
    """LeRobot-facing wrapper around a single RoboCasa task.

    Observations are reshaped into ``{"pixels": {cam: HxWx3 uint8},
    "agent_pos": float32[OBS_STATE_DIM]}`` and actions are flat 12-dim
    vectors converted to RoboCasa's action dict (see ``convert_action``).
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 20}

    def __init__(
        self,
        task: str,
        # Tuple default instead of a list: avoids the shared-mutable-default-
        # argument pitfall (the previous list default was shared across calls).
        camera_name: Sequence[str] = (
            "robot0_agentview_left",
            "robot0_eye_in_hand",
            "robot0_agentview_right",
        ),
        render_mode: str = "rgb_array",
        obs_type: str = "pixels_agent_pos",
        observation_width: int = 256,
        observation_height: int = 256,
        split: str | None = None,
        **kwargs,
    ):
        """Build the environment for one RoboCasa task.

        Args:
            task: RoboCasa task name; must be a key of the atomic/composite
                task registries.
            camera_name: camera(s) to render.
            render_mode: only "rgb_array" is advertised in ``metadata``.
            obs_type: "pixels" or "pixels_agent_pos" ("state" is unsupported).
            observation_width: rendered image width in pixels.
            observation_height: rendered image height in pixels.
            split: optional dataset split forwarded to the base env.
            **kwargs: forwarded to ``RoboCasaGymEnv``; a "fps" entry is
                dropped because the base env does not accept it.

        Raises:
            ValueError: if ``task`` is not a known RoboCasa task.
        """
        self.obs_type = obs_type
        self.render_mode = render_mode
        self.split = split
        self.task = task
        kwargs.pop("fps", None)  # injected by make_env; not a base-env kwarg
        self.kwargs = kwargs
        meta_info = {**ATOMIC_TASK_DATASETS, **COMPOSITE_TASK_DATASETS}
        try:
            self._max_episode_steps = meta_info[task]['horizon']
        except KeyError:
            raise ValueError(f"Unknown task '{task}'. Valid tasks are: {list(meta_info.keys())}")
        super().__init__(
            task,
            camera_names=camera_name,
            camera_widths=observation_width,
            camera_heights=observation_height,
            enable_render=(render_mode is not None),
            split=split,
            **kwargs,
        )

    def _create_obs_and_action_space(self):
        """Define gym spaces matching the formatted observations and actions."""
        images = {}
        for cam in self.camera_names:
            images[cam] = spaces.Box(
                low=0, high=255, shape=(self.camera_heights, self.camera_widths, 3), dtype=np.uint8
            )
        if self.obs_type == "state":
            raise NotImplementedError("The 'state' observation type is not supported.")
        elif self.obs_type == "pixels":
            self.observation_space = spaces.Dict({"pixels": spaces.Dict(images)})
        elif self.obs_type == "pixels_agent_pos":
            self.observation_space = spaces.Dict({
                "pixels": spaces.Dict(images),
                "agent_pos": spaces.Box(low=-1000, high=1000, shape=(OBS_STATE_DIM,), dtype=np.float32),
            })
        else:
            raise ValueError(f"Unknown obs_type: {self.obs_type}")
        self.action_space = spaces.Box(
            low=ACTION_LOW, high=ACTION_HIGH, shape=(int(ACTION_DIM),), dtype=np.float32
        )

    def reset(self, seed: int | None = None, **kwargs):
        """Reset the underlying env and return the LeRobot-formatted obs."""
        # NOTE(review): frees the offscreen GL context before resetting —
        # presumably a workaround for render-context leaks in mujoco; relies
        # on private sim internals. Confirm against the mujoco version in use.
        self.unwrapped.sim._render_context_offscreen.gl_ctx.free()
        observation, info = super().reset(seed, **kwargs)
        return self._format_raw_obs(observation), info

    def _format_raw_obs(self, raw_obs: dict):
        """Map raw RoboCasa obs keys into {"pixels": ..., "agent_pos": ...}."""
        new_obs = {}
        if self.obs_type == "pixels_agent_pos":
            new_obs["agent_pos"] = convert_state(raw_obs)
        new_obs["pixels"] = {}
        for k, v in raw_obs.items():
            # "video.<cam_name>" entries become pixels keyed by camera name.
            if "video." in k:
                new_obs["pixels"][k.replace("video.", "")] = v
        return new_obs

    def step(self, action: np.ndarray):
        """Step with a flat 12-dim action; returns formatted obs and info.

        Marks the episode terminated on env ``done`` or task success.
        """
        self.unwrapped.sim._render_context_offscreen.gl_ctx.make_current()
        action_dict = convert_action(action)
        observation, reward, done, truncated, info = super().step(action_dict)
        new_obs = self._format_raw_obs(observation)
        is_success = bool(info.get("success", 0))
        terminated = done or is_success
        info.update({"task": self.task, "done": done, "is_success": is_success})
        if terminated:
            info["final_info"] = {"task": self.task, "done": bool(done), "is_success": bool(is_success)}
            # NOTE(review): auto-reset here means the returned new_obs is the
            # PRE-reset observation while the env is already reset — verify
            # this matches the vector-env autoreset convention expected by
            # the LeRobot evaluation loop.
            self.reset()
        return new_obs, reward, terminated, truncated, info
def _make_env_fns(task_name: str, n_envs: int, camera_names: list[str], gym_kwargs: Mapping[str, Any]):
def _make_env(episode_index: int, **kwargs):
seed = kwargs.pop("seed", episode_index)
return RoboCasaEnv(task=task_name, camera_name=camera_names, seed=seed, **kwargs)
return [partial(_make_env, i, **gym_kwargs) for i in range(n_envs)]
# ======================================================================
# LeRobot Hub ํ•„์ˆ˜ ์ง„์ž…์ (Entry Point)
# ======================================================================
def make_env(n_envs: int = 1, use_async_envs: bool = False, cfg=None) -> dict[str, dict[int, Any]]:
    """Main entry point LeRobot calls when loading this env from the Hub.

    Args:
        n_envs: number of parallel environments per task.
        use_async_envs: use a spawn-context AsyncVectorEnv instead of
            SyncVectorEnv.
        cfg: optional config object; any attribute matching a key of
            ``defaults`` below overrides that default.

    Returns:
        ``{task_name: {0: VectorEnv}}`` — one vectorized env per resolved
        task (task_id is always 0).
    """
    env_cls = partial(gym.vector.AsyncVectorEnv, context="spawn") if use_async_envs else gym.vector.SyncVectorEnv

    # Single source of truth for defaults; previously the cfg / no-cfg
    # branches duplicated this dict and diverged (the no-cfg branch was
    # missing "fps"). RoboCasaEnv.__init__ pops "fps", so always passing it
    # is harmless.
    defaults = {
        "task": "CloseFridge",
        "obs_type": "pixels_agent_pos",
        "render_mode": "rgb_array",
        "observation_width": 256,
        "observation_height": 256,
        "camera_name": "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right",
        "split": None,
        "fps": 20,
    }
    if cfg is not None:
        gym_kwargs = {key: getattr(cfg, key, default) for key, default in defaults.items()}
    else:
        gym_kwargs = dict(defaults)
    task_name = gym_kwargs.pop("task")
    parsed_camera_names = _parse_camera_names(gym_kwargs.pop("camera_name"))

    combined_tasks = {**TARGET_TASKS, **PRETRAINING_TASKS}
    # A benchmark name expands to its registered task list; otherwise the
    # string is treated as a comma-separated list of individual tasks.
    if task_name in combined_tasks:
        task_names = combined_tasks[task_name]
        gym_kwargs["split"] = "target" if task_name in TARGET_TASKS else "pretrain"
    else:
        task_names = [t.strip() for t in task_name.split(",")]

    out = defaultdict(dict)
    for task in task_names:
        fns = _make_env_fns(
            task_name=task,
            n_envs=n_envs,
            camera_names=parsed_camera_names,
            gym_kwargs=gym_kwargs,
        )
        out[task][0] = env_cls(fns)
    # Return shape: {suite_name: {task_id: VectorEnv}}.
    return {suite: dict(task_map) for suite, task_map in out.items()}