Spaces:
Running
on
Zero
Running
on
Zero
# -------------------------------------------------------- | |
# Licensed under The MIT License [see LICENSE for details] | |
# -------------------------------------------------------- | |
""" | |
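Generator over robomimic demonstrations.

For every supported task, replays the simulator states stored in the task's
``ph/image.hdf5`` dataset to render one image per step, and yields each
demonstration as a dict of steps (low-dimensional state, rendered image,
action, and the task name as the language instruction).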
""" | |
import h5py | |
import numpy as np | |
import cv2 | |
import time | |
from collections import OrderedDict | |
import robomimic.utils.file_utils as FileUtils | |
from sim.robomimic.robomimic_runner import ( | |
create_env, OBS_KEYS, RESOLUTION | |
) | |
from sim.robomimic.robomimic_wrapper import RobomimicLowdimWrapper | |
from typing import Optional, Iterable | |
# Root directory under which each task keeps its `<env_name>/ph/image.hdf5` file.
DATASET_DIR = "data/robomimic/datasets"

# Task names; their order defines how global episode indices map to tasks.
SUPPORTED_ENVS = [
    "lift",
    "square",
    "can",
]

# Number of global episode indices reserved for each task.
NUM_EPISODES_PER_TASK = 200
def render_step(env, state):
    """Render the environment at a given flattened simulator state.

    Args:
        env: wrapped environment; the simulator is reached through
            ``env.env.env.sim``.
        state: flattened simulator state to restore before rendering.

    Returns:
        The frame returned by ``env.render()``, resized to ``RESOLUTION``.
    """
    sim = env.env.env.sim
    # Restore the recorded state, then propagate it through the simulator so
    # the rendered frame reflects it.
    sim.set_state_from_flattened(state)
    sim.forward()
    frame = env.render()
    return cv2.resize(frame, RESOLUTION)
def robomimic_dataset_size() -> int:
    """Return the size of the global episode index space.

    Every task in ``SUPPORTED_ENVS`` is allotted ``NUM_EPISODES_PER_TASK``
    indices, so the total is the product of the two.
    """
    num_tasks = len(SUPPORTED_ENVS)
    return num_tasks * NUM_EPISODES_PER_TASK
def robomimic_dataset_generator(example_inds: Optional[Iterable[int]] = None) -> Iterable[dict]:
    """Yield robomimic demonstrations as episodes of rendered steps.

    Global episode indices are laid out task by task: index ``i`` refers to
    ``demo_{i % NUM_EPISODES_PER_TASK}`` of the task
    ``SUPPORTED_ENVS[i // NUM_EPISODES_PER_TASK]``. Indices whose demo is not
    present in the task's HDF5 file are silently skipped.

    Args:
        example_inds: global episode indices to generate, in order. Defaults
            to every index in ``range(robomimic_dataset_size())``. An
            environment is rebuilt whenever consecutive indices belong to
            different tasks, so grouping indices by task avoids rebuilds.

    Yields:
        ``{"steps": [...]}`` where each step is an ``OrderedDict`` with
        ``"observation"`` (``"state"``: the ``OBS_KEYS`` observations
        concatenated along the last axis as float32, ``"image"``: the frame
        rendered from the recorded simulator state), ``"action"`` (float32)
        and ``"language_instruction"`` (the task name).

    Raises:
        IndexError: if an index maps past the end of ``SUPPORTED_ENVS``.
    """
    if example_inds is None:
        example_inds = range(robomimic_dataset_size())

    env = None
    curr_env_name = None
    dataset = None
    # try/finally so the simulator is released even when the consumer stops
    # iterating early (generator close) or an error propagates.
    try:
        for idx in example_inds:
            # get env_name corresponding to idx
            env_name = SUPPORTED_ENVS[idx // NUM_EPISODES_PER_TASK]
            if curr_env_name != env_name:
                # need to load new env; release the previous one first so
                # simulators do not accumulate across tasks
                if env is not None:
                    env.close()
                    env = None
                dataset = f"{DATASET_DIR}/{env_name}/ph/image.hdf5"
                env_meta = FileUtils.get_env_metadata_from_dataset(dataset)
                env_meta["use_image_obs"] = True
                env = create_env(env_meta=env_meta, obs_keys=OBS_KEYS)
                env = RobomimicLowdimWrapper(env=env)
                env.reset()  # NOTE: this is necessary to remove green laser bug
                curr_env_name = env_name

            demo_key = f"demo_{idx % NUM_EPISODES_PER_TASK}"
            # Open read-only and copy everything needed into memory, so the
            # file is closed before rendering and before yielding.
            with h5py.File(dataset, "r") as file:
                demos = file["data"]
                if demo_key not in demos:
                    continue
                demo = demos[demo_key]
                obs = demo["obs"]
                # [:] reads all states at once instead of row by row
                states = demo["states"][:]
                action = demo["actions"][:].astype(np.float32)
                step_obs = np.concatenate(
                    [obs[key] for key in OBS_KEYS], axis=-1
                ).astype(np.float32)

            # break into step dicts, rendering one image per recorded state
            steps = [
                OrderedDict(
                    {
                        "observation": {"state": o, "image": render_step(env, s)},
                        "action": a,
                        "language_instruction": f"{env_name}",
                    }
                )
                for a, o, s in zip(action, step_obs, states)
            ]
            yield {"steps": steps}
    finally:
        # env stays None when no index was processed; nothing to close then
        if env is not None:
            env.close()