|
|
| import dill |
| import h5py |
| import numpy as np |
| import torch |
|
|
def get_text_tokens(cfg, tokenizer, text_model, goal, model=None):
    """
    Encode the language goal `goal` into tokens/embeddings.

    Parameters
    ----------
    cfg : config object (unused when `model` is provided).
    tokenizer : tokenizer forwarded to `model.encode_text_goal`.
    text_model : text encoder forwarded to `model.encode_text_goal`.
    goal : str, the language instruction to encode.
    model : policy model exposing `encode_text_goal`; when given, encoding is
        delegated to it so callers don't need a buffer.

    Returns
    -------
    Whatever `model.encode_text_goal` returns.

    Raises
    ------
    RuntimeError
        If `model` is None: standalone encoding requires a buffer, so we fail
        fast instead of computing a result that was previously discarded by
        an unconditional raise (the old standalone T5/character paths were
        dead code).
    """
    if model is not None:
        return model.encode_text_goal(goal, tokenizer=tokenizer, text_model=text_model)

    # Fail fast: there is no supported model-free encoding path.
    raise RuntimeError("Text encoding without model requires a buffer; pass model into get_text_tokens")
|
|
def get_blocked_mask(cfg, targets=None, T=0):
    """
    Build a 1-D token mask of length 1 + c*obs_stacking + T + c, where
    c = 192 is the number of tokens per encoded image.

    Layout: [1 text/goal token | c*obs_stacking observation tokens |
             T target tokens | c goal-image tokens].

    When `targets` is provided, one of the two trailing segments is randomly
    zeroed out (roughly 1/3 chance each) for masked training; with
    `targets=None` the mask is all ones.

    Returns
    -------
    torch.Tensor of shape (1 + c*obs_stacking + T + c,) on cfg.device.
    """
    c = 192  # tokens per encoded image
    obs_len = c * cfg.policy.obs_stacking
    mask = torch.ones((1 + obs_len + T + c,), device=cfg.device)
    if targets is not None:
        if torch.rand(1)[0] > 0.66:
            # Mask out the target-token segment.
            mask[1 + obs_len: 1 + obs_len + T] = 0
        elif torch.rand(1)[0] > 0.33:
            # Mask out the goal-image segment.
            mask[1 + obs_len + T: 1 + obs_len + T + c] = 0
    # Fix: the previous implementation built the mask but never returned it,
    # so callers received None.
    return mask
|
|
def eval_model_in_sim(cfg, model, device, log_dir, env, env_unwrapped,
                      wandb, iter_, tokenizer=None, text_model=None):
    """
    Roll the policy out in a SimplerEnv/ManiSkill2 environment and log results.

    For each of cfg.sim.eval_episodes episodes: resets the env, encodes the
    language instruction once, then repeatedly feeds stacked RGB frames,
    end-effector pose and the previous action to `model.forward`, decoding and
    executing cfg.policy.action_stacking low-level actions per model call.
    Saves an mp4 of the rollout and logs mean reward to wandb unless
    cfg.testing is set.

    Parameters
    ----------
    cfg : hydra config; reads cfg.sim.eval_episodes, cfg.policy.*, cfg.action_dim, cfg.testing.
    model : policy exposing preprocess_state / preprocess_goal_image /
        forward / decode_action / encode_text_goal.
    device : torch device for model inputs.
    log_dir : directory where the episode video is written.
    env, env_unwrapped : the SimplerEnv environment and its unwrapped core.
    wandb : wandb module/run (only used when cfg.testing is falsy).
    iter_ : evaluation iteration index (video filename).
    tokenizer, text_model : optional T5 pieces forwarded to get_text_tokens.

    Returns
    -------
    episode_stats dict of the last episode (rewards are averaged across all
    steps of all episodes seen so far).
    """
    from simpler_env.utils.env.observation_utils import get_image_from_maniskill2_obs_dict
    print("Evaluating model in sim environment")
    from collections import deque
    from einops import rearrange

    rewards = []
    for j in range(cfg.sim.eval_episodes):
        obs, reset_info = env.reset()
        obs_ = get_image_from_maniskill2_obs_dict(env_unwrapped, obs)[:, :, :3]  # keep RGB only
        obs_hist = deque(maxlen=cfg.policy.obs_stacking)
        last_action = np.zeros(cfg.action_dim)
        # Seed the frame history with copies of the first observation.
        for _ in range(cfg.policy.obs_stacking):
            obs_hist.append(obs_)
        instruction = env_unwrapped.get_language_instruction()

        print("Instruction", instruction)
        frames = []
        done, truncated, timeLimit, t = False, False, 100, 0
        txt_goal = get_text_tokens(cfg, tokenizer, text_model, instruction, model=model)

        # NOTE(review): obs_hist is never updated inside this loop, so the
        # policy always sees the initial frame stack — confirm whether new
        # frames should be appended each step.
        while not (done or truncated or (t > timeLimit)):

            # Stack history along channels: (h, w, 3*obs_stacking).
            image = np.stack(obs_hist, axis=-1)
            image = rearrange(image, 'h w c t -> h w (c t)')

            obs_state = torch.tensor(model.preprocess_state(image), dtype=torch.float32)
            goal_state = torch.tensor(model.preprocess_goal_image(image[:, :, :3]), dtype=torch.float32)

            last_action_tensor = None
            if last_action is not None:
                last_action_tensor = torch.tensor(last_action[:cfg.action_dim], dtype=torch.float32).unsqueeze(0).to(device)

            # obs_state/goal_state are already float32 tensors; re-wrapping them
            # in torch.tensor() made an extra copy and raised a UserWarning.
            action, loss = model.forward(obs_state.unsqueeze(0).to(device),
                                         torch.as_tensor(txt_goal).to(device),
                                         goal_state.unsqueeze(0).to(device),
                                         mask_=True,
                                         pose=torch.tensor([[obs["extra"]["tcp_pose"]]], dtype=torch.float32).to(device),
                                         last_action=last_action_tensor,
                                         )

            action = model.decode_action(action[0]).cpu().detach().numpy()
            last_action = action.copy()

            # The model emits action_stacking chunks of cfg.action_dim each.
            for step_ in range(cfg.policy.action_stacking):
                act_ = action[cfg.action_dim * step_:(cfg.action_dim * (step_ + 1))]
                obs, reward, done, truncated, info = env.step(act_)
                image = get_image_from_maniskill2_obs_dict(env_unwrapped, obs)
                image = image[:, :, :3]

                frames.append(image)
                # NOTE(review): both terms use "eof_to_obj1_diff", so the
                # distance is counted twice — the second was probably meant to
                # be a different key. Kept as-is to preserve behavior; confirm
                # against the env's info dict.
                reward = -(np.linalg.norm(info["eof_to_obj1_diff"]) + np.linalg.norm(info["eof_to_obj1_diff"]))
                rewards.append(reward)
                t = t + 1
                if done or truncated:
                    break

        episode_stats = info.get('episode_stats', {})
        episode_stats['rewards'] = np.mean(rewards)

        print(f"avg reward {np.mean(episode_stats['rewards']):.8f}")
        if not cfg.testing:
            wandb.log({"avg reward": np.mean(rewards)})

        import os
        # One video per evaluation call; later episodes overwrite earlier ones.
        path_ = os.path.join(log_dir, f"simple-env-{iter_}.mp4")
        import imageio
        imageio.mimsave(path_, frames, fps=20)
        episode_stats['video_url'] = path_

        if not cfg.testing:
            try:
                wandb.log({"example": wandb.Video(path_)})
            except Exception as e:
                print(f"Warning: failed to log video to wandb: {e}")

    return episode_stats
|
|
| import gymnasium as gym |
| |
class DictWrapper(gym.ObservationWrapper):

    """
    A wrapper that grabs the observation from a specific key in the dictionary.

    Also adapts the wrapped env's old-style gym API (4-tuple ``step``, bare
    ``reset``) to the gymnasium 5-tuple / 2-tuple API, and flips the extracted
    image vertically (``[::-1]``) in ``step``/``reset``.

    NOTE(review): ``super().__init__`` is deliberately not called — presumably
    because the wrapped robosuite/LIBERO env is not a gymnasium ``Env`` and
    would fail gymnasium's wrapper type checks; confirm before "fixing".
    """
    def __init__(self, env, obs_key=""):

        # Bypasses gym.ObservationWrapper.__init__ on purpose (see class note);
        # gym.Wrapper.__getattr__ still delegates to self.env.
        self.env = env
        # Advertise the extracted image space: 256x256 RGB uint8.
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(256,256,3),
            dtype=np.uint8)
        self._obs_key = obs_key  # dict key to pull the image observation from


    def observation(self, observation):
        """
        This method is called by the gym.ObservationWrapper after the environment's
        step or reset methods return an observation.

        NOTE(review): unlike step()/reset() below, this path does not flip the
        image vertically; it is effectively bypassed since both are overridden.
        """

        return observation[self._obs_key]

    def step(self, action):
        """
        Step the environment and return the observation from the specified key.

        Returns a gymnasium-style 5-tuple (obs, reward, terminated, truncated,
        info); the image is flipped vertically and the full raw observation
        dict is passed back in the info slot. `truncated` is always False.
        """
        obs, reward, done, info = self.env.step(action)
        return obs[self._obs_key][::-1, :, :], reward, done, False, obs


    def reset(self, **kwargs):
        """
        Reset the environment and return the observation from the specified key.

        Returns (flipped image, raw observation dict) in gymnasium
        (obs, info) style. `kwargs` are accepted but not forwarded.
        """
        obs = self.env.reset()
        return obs[self._obs_key][::-1, :, :], obs
|
|
def eval_libero(model, device, cfg, iter_=0, log_dir="./",
                tokenizer=None, text_model=None, wandb=None):
    """
    Evaluate the policy on a LIBERO task suite.

    For each task id in cfg.sim.eval_tasks: builds a DenseRewardEnv from the
    task's BDDL file, optionally restores initial states (and goal images)
    from a Hugging Face dataset or an HDF5 file (falling back to the suite's
    defaults), rolls the policy out, saves an mp4 per episode, and logs the
    mean reward to wandb unless cfg.testing is set.

    Parameters
    ----------
    model : policy exposing preprocess_state / preprocess_goal_image /
        encode_pose / encode_action / forward / decode_action.
    device : torch device for model inputs.
    cfg : hydra config; reads cfg.sim.*, cfg.policy.*, cfg.action_dim, cfg.testing.
    iter_ : evaluation iteration index (used in video filenames).
    log_dir : directory where episode videos are written.
    tokenizer, text_model : optional T5 pieces forwarded to get_text_tokens.
    wandb : wandb module/run (only used when cfg.testing is falsy).

    Returns
    -------
    episode_stats dict of the LAST episode evaluated; earlier episodes are
    only logged, not aggregated into the return value.
    """
    from libero.libero import benchmark
    from libero.libero.envs import OffScreenRenderEnv, DenseRewardEnv
    import os
    from libero.libero.utils import get_libero_path
    from gymnasium.wrappers import FrameStackObservation
    from einops import rearrange

    benchmark_dict = benchmark.get_benchmark_dict()
    task_suite_name = cfg.sim.task_set
    task_suite = benchmark_dict[task_suite_name]()

    # Optionally load curated initial states: HF dataset takes precedence,
    # then an HDF5 file, else the suite's built-in init states are used.
    init_states_dataset = None
    if hasattr(cfg.sim, 'libero_init_state_hf_repo') and cfg.sim.libero_init_state_hf_repo:
        print(f"Loading initial states from Hugging Face: {cfg.sim.libero_init_state_hf_repo}")
        from datasets import load_dataset
        init_states_dataset = load_dataset(cfg.sim.libero_init_state_hf_repo, split='train')
        print(f"Loaded dataset with {len(init_states_dataset)} entries")
    elif hasattr(cfg.sim, 'libero_init_state_file') and cfg.sim.libero_init_state_file:
        print(f"Loading initial states from HDF5: {cfg.sim.libero_init_state_file}")
        # NOTE(review): relies on the module-level `import hydra` near the
        # bottom of this file having executed before this call (it has, at
        # import time) — confirm when refactoring.
        init_states_dataset = h5py.File(hydra.utils.get_original_cwd()+cfg.sim.libero_init_state_file, 'r')

    tasks = cfg.sim.eval_tasks
    for task_id in tasks:
        task = task_suite.get_task(task_id)
        task_name = task.name
        instruction = task.language
        task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
        print(f"[info] retrieving task {task_id} from suite {task_suite_name}, the " + \
              f"language instruction is {instruction}, and the bddl file is {task_bddl_file}")

        env_args = {
            "bddl_file_name": task_bddl_file,
            "camera_heights": 256,
            "camera_widths": 256
        }
        env = DenseRewardEnv(**env_args)
        env.seed(0)

        # Datasets key tasks by the instruction with underscores for spaces.
        task_description = instruction.replace(" ", "_")
        task_demos = None
        if init_states_dataset is not None:
            if isinstance(init_states_dataset, h5py.File):
                # HDF5 layout: /<task_description>/demo_<i>/{init_state, goal_img}
                if task_description in init_states_dataset:
                    task_grp = init_states_dataset[task_description]
                    num_init_states = len(task_grp.keys())
                    print(f"Loaded {num_init_states} initial states from HDF5 for task: {task_description}")
                else:
                    task_grp = None
                    init_states = task_suite.get_task_init_states(task_id)
                    num_init_states = len(init_states)
                    print(f"Using default initial states for task: {task_description}")
            else:
                # HF dataset: filter rows belonging to this task.
                task_demos = [item for item in init_states_dataset if item.get('task_description') == task_description]
                num_init_states = len(task_demos)
                if num_init_states > 0:
                    print(f"Loaded {num_init_states} initial states from HF dataset for task: {task_description}")
                else:
                    init_states = task_suite.get_task_init_states(task_id)
                    num_init_states = len(init_states)
                    print(f"Using default initial states for task: {task_description}")
        else:
            init_states = task_suite.get_task_init_states(task_id)
            num_init_states = len(init_states)
            print(f"Using default initial states for task: {task_description}")

        # Evaluate at most 2 initial states per task.
        for init_state_id in range(min(2, num_init_states)):

            # Resolve (init_state, goal_img) for this episode from whichever
            # source was selected above.
            if init_states_dataset is not None:
                if isinstance(init_states_dataset, h5py.File):
                    # NOTE(review): if task_grp exists but demo_<i> is missing,
                    # the fallback reads `init_states`, which was never loaded
                    # on this branch — potential NameError; confirm.
                    if task_grp is not None:
                        demo_key = f"demo_{init_state_id}"
                        if demo_key in task_grp:
                            init_state = task_grp[demo_key]['init_state'][()]
                            goal_img = task_grp[demo_key]['goal_img'][()] if 'goal_img' in task_grp[demo_key] else None
                            print(f"Loaded init_state and goal_img from HDF5 for {demo_key}")
                        else:
                            init_state = init_states[init_state_id]
                            goal_img = None
                    else:
                        init_state = init_states[init_state_id]
                        goal_img = None
                else:
                    # HF dataset path.
                    if task_demos and init_state_id < len(task_demos):
                        demo = task_demos[init_state_id]
                        init_state = np.array(demo['init_state'])
                        goal_img = np.array(demo['goal_img']) if 'goal_img' in demo and demo['goal_img'] is not None else None
                        print(f"Loaded init_state and goal_img from HF dataset for demo {init_state_id}")
                    else:
                        init_state = init_states[init_state_id]
                        goal_img = None
            else:
                init_state = init_states[init_state_id]
                goal_img = None

            # Reset, restore the initial state, then wrap with frame stacking
            # over the agentview camera image.
            env.reset()
            env.set_init_state(init_state)
            env_ = FrameStackObservation(DictWrapper(env, obs_key="agentview_image"), cfg.policy.obs_stacking)
            obs, info = env_.reset()

            # NOTE(review): the returned mask is unused below (and
            # get_blocked_mask may not return one) — confirm intent.
            mask = get_blocked_mask(cfg, targets=None, T=0)

            txt_goal = get_text_tokens(cfg, tokenizer, text_model, instruction, model=model)

            if goal_img is not None:
                image_goal = goal_img
                print(f"Using goal image from HDF5, shape: {image_goal.shape}")
            else:
                # Fall back to the first stacked frame as the goal image.
                image_goal = obs.reshape((256, 256, 3*cfg.policy.obs_stacking))[:,:,:3]
                print("Using first observation as goal image")
            frames = []
            rewards = []
            infos = []
            last_action = np.zeros(cfg.action_dim)
            # wait_steps=00 (i.e. 0): no-op warmup steps before the policy acts.
            done, truncated, timeLimit, t, wait_steps = False, False, 400, 0, 00

            while not (done or truncated or (t > (timeLimit + wait_steps))):

                # Warmup phase: send zero actions until wait_steps have passed.
                if t < wait_steps:
                    obs, reward, done, truncated, info = env_.step([0,0,0,0,0,0,0])
                    t += 1
                    continue

                # (t, h, w, c) frame stack -> (h, w, t*c) channel-stacked image.
                obs = rearrange(obs, 't h w c -> h w (t c)', c=3, t=cfg.policy.obs_stacking)

                obs_state = model.preprocess_state(obs)
                goal_state = model.preprocess_goal_image(image_goal)
                # Pose = eef position + first 3 quaternion components + gripper joint.
                pose_ = model.encode_pose(torch.tensor([[np.concatenate(
                    (info["robot0_eef_pos"],
                     info["robot0_eef_quat"][:3],
                     [(info["robot0_gripper_qpos"][0])]), axis=-1)]],
                    dtype=torch.float32)).to(device)

                last_action_tensor = None
                if last_action is not None:
                    last_action_tensor = model.encode_action(torch.tensor([last_action[:cfg.action_dim]], dtype=torch.float32)).to(device)

                action, loss = model.forward(torch.tensor(np.array([obs_state])).to(device)
                                             ,torch.tensor(txt_goal).to(device)
                                             ,torch.tensor(np.array([goal_state])).to(device),
                                             mask_=True,
                                             pose=pose_,
                                             last_action=last_action_tensor,
                                             )

                action = model.decode_action(action[0]).cpu().detach().numpy()
                last_action = action.copy()

                # Execute the action_stacking decoded chunks one at a time.
                for step_ in range(cfg.policy.action_stacking):
                    act_ = action[cfg.action_dim*step_:(cfg.action_dim*(step_+1))]

                    obs, reward, done, truncated, info = env_.step(act_)

                    # Most recent frame of the stack, for the video.
                    image = obs[0]
                    frames.append(image)

                    rewards.append(reward)
                    infos.append(info)
                    t=t+1

                    if done or truncated:
                        print("Episode finished with success after {} timesteps".format(step_))
                        break
                # NOTE(review): largely redundant with the while-condition;
                # exits the outer loop immediately on success.
                if done:
                    print("Episode finished with success after {} timesteps".format(step_))
                    break

            import os
            path_ = os.path.join(log_dir, f"libero-{iter_}-task-id-{task_id}-init-id-{init_state_id}.mp4")
            import imageio
            imageio.mimsave(path_, frames, fps=20)
            episode_stats = info.get('episode_stats', {})
            episode_stats['rewards'] = np.mean(rewards)
            episode_stats['video_url'] = path_
            print(f"avg reward {np.mean(rewards):.8f}")
            if not cfg.testing:
                wandb.log({"avg reward_"+str(task_id): np.mean(rewards)})
            if not cfg.testing:
                wandb.log({"example": wandb.Video(path_)})
        env.close()

    # Close the HDF5 handle if we opened one.
    if init_states_dataset is not None and isinstance(init_states_dataset, h5py.File):
        init_states_dataset.close()
        print("Closed HDF5 file")

    return episode_stats
|
|
| import hydra |
| from omegaconf import DictConfig |
|
|
@hydra.main(config_path="./conf", config_name="64pix-pose")
def my_main(cfg: DictConfig):
    """
    Hydra entry point: load a trained GRP model (or a dataset-replay mock),
    optionally build the T5 text encoder, and run the evaluations selected
    by cfg.simEval ("libero" and/or "simple_env").
    """
    import torch

    # Hydra's per-run output directory; videos and logs go here.
    log_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
    cfg.dataset.load_dataset = "skip"

    model_dir = hydra.utils.get_original_cwd()+"/mini-grp/miniGRP.pth"
    print ("Loading model from:", model_dir)
    if "dataset" == cfg.model.type:

        # "dataset" mode: replay actions from a buffer instead of a policy.
        from mini_shuffel_buffer import CircularBuffer
        from mock_grp_model import ReplayModel
        cfg.dataset.load_dataset = True
        model_ = ReplayModel(cfg)
        dataset_buffer = CircularBuffer(cfg.dataset.buffer_size, cfg, model=model_)
        model_.set_dataset(dataset_buffer)
    else:
        # GRP is imported so dill/torch.load can resolve the pickled class —
        # do not remove even though the name looks unused.
        from grp_model import GRP
        model_ = torch.load(model_dir, pickle_module=dill)

    # Build the T5 tokenizer/encoder only when the dataset was encoded with T5.
    tokenizer = None
    text_model = None
    if cfg.dataset.encode_with_t5:
        from transformers import T5Tokenizer, T5ForConditionalGeneration
        tokenizer = T5Tokenizer.from_pretrained(cfg.dataset.t5_version)
        text_model = T5ForConditionalGeneration.from_pretrained(cfg.dataset.t5_version)

    if "libero" in cfg.simEval:
        results = eval_libero(model_.to(cfg.device), device=cfg.device, cfg=cfg,
                              iter_=0, tokenizer=tokenizer, text_model=text_model, wandb=None,
                              log_dir=log_dir)
    if "simple_env" in cfg.simEval:
        import simpler_env
        task_name = "widowx_carrot_on_plate"
        # NOTE(review): `env` is never defined before this point, so this
        # guard is always False here — dead defensive code; confirm.
        if 'env' in locals():
            print("Closing existing env")
            env.close()
            del env
        env = simpler_env.make(task_name)
        # Unwrap three wrapper layers to reach the core ManiSkill2 env.
        env_unwrapped = env.env.env.env
        results = eval_model_in_sim(cfg, model_.to(cfg.device), device=cfg.device, log_dir=log_dir,
                                    env=env, env_unwrapped=env_unwrapped,
                                    wandb=None, iter_=0, tokenizer=tokenizer, text_model=text_model)
    # NOTE(review): `results` is unbound if cfg.simEval selects neither
    # evaluation — this would raise NameError; confirm intended configs.
    print("results:", results)
|
|
| |
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hydra parses the CLI/config and invokes my_main.
    results = my_main()
    print("results:", results)