from collections import OrderedDict, deque
from typing import Any, NamedTuple
import os

import dm_env
import numpy as np
from dm_env import StepType, specs
import gym
import torch
class ExtendedTimeStep(NamedTuple):
    step_type: Any
    reward: Any
    discount: Any
    observation: Any
    action: Any

    def first(self):
        return self.step_type == StepType.FIRST

    def mid(self):
        return self.step_type == StepType.MID

    def last(self):
        return self.step_type == StepType.LAST

    def __getitem__(self, attr):
        return getattr(self, attr)
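
# Note: __getitem__ makes the fields addressable by name, so for an
# ExtendedTimeStep ts, ts['reward'] is equivalent to ts.reward.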

class FlattenJacoObservationWrapper(dm_env.Environment):
    """Flattens Jaco observations into a single 'observations' vector and
    exposes the 'front_close' camera as 'pixels'."""

    def __init__(self, env):
        self._env = env
        self._obs_spec = OrderedDict()
        wrapped_obs_spec = env.observation_spec().copy()
        if 'front_close' in wrapped_obs_spec:
            spec = wrapped_obs_spec['front_close']
            # drop batch dim
            self._obs_spec['pixels'] = specs.BoundedArray(shape=spec.shape[1:],
                                                          dtype=spec.dtype,
                                                          minimum=spec.minimum,
                                                          maximum=spec.maximum,
                                                          name='pixels')
            wrapped_obs_spec.pop('front_close')

        for spec in wrapped_obs_spec.values():
            assert spec.dtype == np.float64
            assert type(spec) == specs.Array

        dim = np.sum(
            np.fromiter((int(np.prod(spec.shape))
                         for spec in wrapped_obs_spec.values()), np.int32))
        self._obs_spec['observations'] = specs.Array(shape=(dim,),
                                                     dtype=np.float32,
                                                     name='observations')

    def _transform_observation(self, time_step):
        obs = OrderedDict()
        if 'front_close' in time_step.observation:
            pixels = time_step.observation['front_close']
            time_step.observation.pop('front_close')
            pixels = np.squeeze(pixels)
            obs['pixels'] = pixels

        features = []
        for feature in time_step.observation.values():
            features.append(feature.ravel())
        obs['observations'] = np.concatenate(features, axis=0)
        return time_step._replace(observation=obs)

    def reset(self):
        time_step = self._env.reset()
        return self._transform_observation(time_step)

    def step(self, action):
        time_step = self._env.step(action)
        return self._transform_observation(time_step)

    def observation_spec(self):
        return self._obs_spec

    def action_spec(self):
        return self._env.action_spec()

    def __getattr__(self, name):
        return getattr(self._env, name)

class ActionRepeatWrapper(dm_env.Environment):
    def __init__(self, env, num_repeats):
        self._env = env
        self._num_repeats = num_repeats

    def step(self, action):
        reward = 0.0
        discount = 1.0
        for _ in range(self._num_repeats):
            time_step = self._env.step(action)
            reward += (time_step.reward or 0.0) * discount
            discount *= time_step.discount
            if time_step.last():
                break
        return time_step._replace(reward=reward, discount=discount)

    def observation_spec(self):
        return self._env.observation_spec()

    def action_spec(self):
        return self._env.action_spec()

    def reset(self):
        return self._env.reset()

    def __getattr__(self, name):
        return getattr(self._env, name)
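
# How ActionRepeatWrapper accumulates reward: with num_repeats=3, per-step
# rewards r1, r2, r3 and discounts d1, d2, d3, the emitted transition carries
# reward r1 + d1*r2 + d1*d2*r3 and discount d1*d2*d3 (stopping early if an
# inner step is LAST).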

class FramesWrapper(dm_env.Environment):
    def __init__(self, env, num_frames=1, pixels_key='pixels'):
        self._env = env
        self._num_frames = num_frames
        self._frames = deque([], maxlen=num_frames)
        self._pixels_key = pixels_key

        wrapped_obs_spec = env.observation_spec()
        assert pixels_key in wrapped_obs_spec

        pixels_shape = wrapped_obs_spec[pixels_key].shape
        # remove batch dim
        if len(pixels_shape) == 4:
            pixels_shape = pixels_shape[1:]
        self._obs_spec = specs.BoundedArray(shape=np.concatenate(
            [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0),
                                            dtype=np.uint8,
                                            minimum=0,
                                            maximum=255,
                                            name='observation')

    def _transform_observation(self, time_step):
        assert len(self._frames) == self._num_frames
        obs = np.concatenate(list(self._frames), axis=0)
        return time_step._replace(observation=obs)

    def _extract_pixels(self, time_step):
        pixels = time_step.observation[self._pixels_key]
        # remove batch dim
        if len(pixels.shape) == 4:
            pixels = pixels[0]
        return pixels.transpose(2, 0, 1).copy()

    def reset(self):
        time_step = self._env.reset()
        pixels = self._extract_pixels(time_step)
        for _ in range(self._num_frames):
            self._frames.append(pixels)
        return self._transform_observation(time_step)

    def step(self, action):
        time_step = self._env.step(action)
        pixels = self._extract_pixels(time_step)
        self._frames.append(pixels)
        return self._transform_observation(time_step)

    def observation_spec(self):
        return self._obs_spec

    def action_spec(self):
        return self._env.action_spec()

    def __getattr__(self, name):
        return getattr(self._env, name)
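
# Example: with num_frames=3 and 84x84 RGB pixels, FramesWrapper emits uint8
# observations of shape (9, 84, 84): frames are transposed to channels-first
# and concatenated along the channel axis.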

class OneHotAction(gym.Wrapper):
    def __init__(self, env):
        assert isinstance(env.action_space, gym.spaces.Discrete)
        super().__init__(env)
        self._random = np.random.RandomState()
        shape = (self.env.action_space.n,)
        space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
        space.discrete = True
        self.action_space = space

    def step(self, action):
        index = np.argmax(action).astype(int)
        reference = np.zeros_like(action)
        reference[index] = 1
        if not np.allclose(reference, action):
            raise ValueError(f"Invalid one-hot action:\n{action}")
        return self.env.step(index)

    def reset(self):
        return self.env.reset()

    def _sample_action(self):
        actions = self.env.action_space.n
        index = self._random.randint(0, actions)
        reference = np.zeros(actions, dtype=np.float32)
        reference[index] = 1.0
        return reference
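
# Example: for a Discrete(4) env, a valid action is
# np.array([0., 0., 1., 0.], dtype=np.float32); step() recovers the index via
# argmax and rejects any vector that is not exactly one-hot.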

class ActionDTypeWrapper(dm_env.Environment):
    def __init__(self, env, dtype):
        self._env = env
        wrapped_action_spec = env.action_spec()
        self._action_spec = specs.BoundedArray(wrapped_action_spec.shape,
                                               dtype,
                                               wrapped_action_spec.minimum,
                                               wrapped_action_spec.maximum,
                                               'action')

    def step(self, action):
        action = action.astype(self._env.action_spec().dtype)
        return self._env.step(action)

    def observation_spec(self):
        return self._env.observation_spec()

    def action_spec(self):
        return self._action_spec

    def reset(self):
        return self._env.reset()

    def __getattr__(self, name):
        return getattr(self._env, name)

class ObservationDTypeWrapper(dm_env.Environment):
    def __init__(self, env, dtype):
        self._env = env
        self._dtype = dtype
        wrapped_obs_spec = env.observation_spec()['observations']
        self._obs_spec = specs.Array(wrapped_obs_spec.shape, dtype,
                                     'observation')

    def _transform_observation(self, time_step):
        obs = time_step.observation['observations'].astype(self._dtype)
        return time_step._replace(observation=obs)

    def reset(self):
        time_step = self._env.reset()
        return self._transform_observation(time_step)

    def step(self, action):
        time_step = self._env.step(action)
        return self._transform_observation(time_step)

    def observation_spec(self):
        return self._obs_spec

    def action_spec(self):
        return self._env.action_spec()

    def __getattr__(self, name):
        return getattr(self._env, name)

class ExtendedTimeStepWrapper(dm_env.Environment):
    def __init__(self, env):
        self._env = env

    def reset(self):
        time_step = self._env.reset()
        return self._augment_time_step(time_step)

    def step(self, action):
        time_step = self._env.step(action)
        return self._augment_time_step(time_step, action)

    def _augment_time_step(self, time_step, action=None):
        if action is None:
            action_spec = self.action_spec()
            action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
        return ExtendedTimeStep(observation=time_step.observation,
                                step_type=time_step.step_type,
                                action=action,
                                reward=time_step.reward or 0.0,
                                discount=time_step.discount or 1.0)

    def observation_spec(self):
        return self._env.observation_spec()

    def action_spec(self):
        return self._env.action_spec()

    def __getattr__(self, name):
        return getattr(self._env, name)
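
# On reset (and whenever the wrapped env reports reward or discount of None,
# as dm_env does on the FIRST step), ExtendedTimeStepWrapper substitutes a
# zero action, reward 0.0, and discount 1.0, so downstream replay buffers
# always see fully populated fields.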

class DMC:
    """Adapts a dm_env environment to the dict-based (time_step, obs) API
    used by the wrappers below."""

    def __init__(self, env):
        self._env = env
        self._ignored_keys = []

    def step(self, action):
        time_step = self._env.step(action)
        assert time_step.discount in (0, 1)
        obs = {
            'reward': time_step.reward,
            'is_first': False,
            'is_last': time_step.last(),
            'is_terminal': time_step.discount == 0,
            'observation': time_step.observation,
            'action': action,
            'discount': time_step.discount,
        }
        return time_step, obs

    def reset(self):
        time_step = self._env.reset()
        obs = {
            'reward': 0.0,
            'is_first': True,
            'is_last': False,
            'is_terminal': False,
            'observation': time_step.observation,
            'action': np.zeros_like(self.act_space['action'].sample()),
            'discount': time_step.discount,
        }
        return time_step, obs

    def __getattr__(self, name):
        if name == 'obs_space':
            obs_spaces = {
                'observation': self._env.observation_spec(),
                'is_first': gym.spaces.Box(0, 1, (), dtype=bool),
                'is_last': gym.spaces.Box(0, 1, (), dtype=bool),
                'is_terminal': gym.spaces.Box(0, 1, (), dtype=bool),
            }
            return obs_spaces
        if name == 'act_space':
            spec = self._env.action_spec()
            action = gym.spaces.Box(np.float32(spec.minimum),
                                    np.float32(spec.maximum),
                                    shape=spec.shape,
                                    dtype=np.float32)
            act_space = {'action': action}
            return act_space
        return getattr(self._env, name)
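
# From DMC onwards, wrappers follow the convention of returning a
# (dm_env.TimeStep, obs_dict) pair per transition, e.g.
#   ts, obs = env.step(action)
#   obs['reward'], obs['is_last'], obs['observation']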

class KitchenWrapper:
    def __init__(
        self,
        name,
        seed=0,
        action_repeat=1,
        size=(64, 64),
    ):
        import envs.kitchen_extra as kitchen_extra
        self._env = {
            'microwave': kitchen_extra.KitchenMicrowaveV0,
            'kettle': kitchen_extra.KitchenKettleV0,
            'burner': kitchen_extra.KitchenBurnerV0,
            'light': kitchen_extra.KitchenLightV0,
            'hinge': kitchen_extra.KitchenHingeV0,
            'slide': kitchen_extra.KitchenSlideV0,
            'top_burner': kitchen_extra.KitchenTopBurnerV0,
        }[name]()

        self._size = size
        self._action_repeat = action_repeat
        self._seed = seed
        self._eval = False

    def eval_mode(self):
        self._env.dense = False
        self._eval = True

    @property
    def obs_space(self):
        spaces = {
            'observation': gym.spaces.Box(0, 255, (3,) + self._size, dtype=np.uint8),
            'is_first': gym.spaces.Box(0, 1, (), dtype=bool),
            'is_last': gym.spaces.Box(0, 1, (), dtype=bool),
            'is_terminal': gym.spaces.Box(0, 1, (), dtype=bool),
            'state': self._env.observation_space,
        }
        return spaces

    @property
    def act_space(self):
        action = self._env.action_space
        return {'action': action}

    def step(self, action):
        # assert np.isfinite(action["action"]).all(), action["action"]
        reward = 0.0
        for _ in range(self._action_repeat):
            state, rew, done, info = self._env.step(action.copy())
            reward += rew
        obs = {
            'reward': reward,
            'is_first': False,
            'is_last': False,  # handled by the TimeLimit wrapper
            'is_terminal': False,  # handled by the per_episode function
            'observation': info['images'].transpose(2, 0, 1).copy(),
            'state': state.astype(np.float32),
            'action': action,
            'discount': 1,
        }
        if self._eval:
            obs['reward'] = min(obs['reward'], 1)
            if obs['reward'] > 0:
                obs['is_last'] = True
        return dm_env.TimeStep(
            step_type=dm_env.StepType.LAST if obs['is_last'] else dm_env.StepType.MID,
            reward=obs['reward'],
            discount=1,
            observation=obs['observation']), obs

    def reset(self):
        state = self._env.reset()
        obs = {
            'reward': 0.0,
            'is_first': True,
            'is_last': False,
            'is_terminal': False,
            'observation': self.get_visual_obs(self._size),
            'state': state.astype(np.float32),
            'action': np.zeros_like(self.act_space['action'].sample()),
            'discount': 1,
        }
        return dm_env.TimeStep(
            step_type=dm_env.StepType.FIRST,
            reward=None,
            discount=None,
            observation=obs['observation']), obs

    def __getattr__(self, name):
        return getattr(self._env, name)

    def get_visual_obs(self, resolution):
        img = self._env.render(resolution=resolution).transpose(2, 0, 1).copy()
        return img
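
# KitchenWrapper's eval behaviour: after eval_mode(), the per-step reward is
# capped at 1 and the episode is flagged LAST on the first positive reward,
# so an evaluation episode ends on its first success and contributes at most
# a reward of 1.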

class ViClipWrapper:
    def __init__(self, env, hd_rendering=False, device='cuda'):
        self._env = env
        try:
            from tools.genrl_utils import viclip_global_instance
        except ImportError:
            from tools.genrl_utils import ViCLIPGlobalInstance
            viclip_global_instance = ViCLIPGlobalInstance()
        if not viclip_global_instance._instantiated:
            viclip_global_instance.instantiate(device)
        self.viclip_model = viclip_global_instance.viclip
        self.viclip_emb_dim = viclip_global_instance.viclip_emb_dim
        self.n_frames = self.viclip_model.n_frames
        self.buffer = deque(maxlen=self.n_frames)
        # NOTE: these are hardcoded for now, as they are the best settings
        self.accumulate = True
        self.accumulate_buffer = []
        self.anticipate_conv1 = False
        self.hd_rendering = hd_rendering

    def hd_render(self, obs):
        if not self.hd_rendering:
            return obs['observation']
        if self._env._domain_name in ['mw', 'kitchen', 'mujoco']:
            return self.get_visual_obs((224, 224))
        else:
            render_kwargs = {**getattr(self, '_render_kwargs', {})}
            render_kwargs.update({'width': 224, 'height': 224})
            return self._env.physics.render(**render_kwargs).transpose(2, 0, 1)

    def preprocess(self, x):
        return x

    # NOTE: process_at_once could be varied to increase FPS, depending on the size of the GPU
    def process_accumulate(self, process_at_once=4):
        self.accumulate = False
        x = np.stack(self.accumulate_buffer, axis=0)
        # split into chunks
        chunks = []
        chunk_idxs = list(range(0, x.shape[0] + 1, process_at_once))
        if chunk_idxs[-1] != x.shape[0]:
            chunk_idxs.append(x.shape[0])
        start = 0
        for end in chunk_idxs[1:]:
            embeds = self.clip_process(x[start:end], bypass=True)
            chunks.append(embeds.cpu())
            start = end
        embeds = torch.cat(chunks, dim=0)
        assert embeds.shape[0] == len(self.accumulate_buffer)
        self.accumulate = True
        self.accumulate_buffer = []
        return [*embeds.cpu().numpy()], 'clip_video'

    def process_episode(self, obs, process_at_once=8):
        self.accumulate = False
        sequences = []
        for j in range(obs.shape[0] - self.n_frames + 1):
            sequences.append(obs[j:j + self.n_frames].copy())
        sequences = np.stack(sequences, axis=0)
        idx_start = 0
        clip_vid = []
        for idx_end in range(process_at_once, sequences.shape[0] + process_at_once, process_at_once):
            x = sequences[idx_start:idx_end]
            with torch.no_grad():
                x = self.clip_process(x, bypass=True)
            clip_vid.append(x)
            idx_start = idx_end
        if len(clip_vid) == 1:  # processed all at once
            embeds = clip_vid[0]
        else:
            embeds = torch.cat(clip_vid, dim=0)
        pad = torch.zeros((self.n_frames - 1, *embeds.shape[1:]), device=embeds.device, dtype=embeds.dtype)
        embeds = torch.cat([pad, embeds], dim=0)
        assert embeds.shape[0] == obs.shape[0], f"Shapes are different {embeds.shape[0]} {obs.shape[0]}"
        return embeds.cpu().numpy()

    def get_sequence(self):
        return np.expand_dims(np.stack(self.buffer, axis=0), axis=0)

    def clip_process(self, x, bypass=False):
        if len(self.buffer) == self.n_frames or bypass:
            if self.accumulate:
                self.accumulate_buffer.append(self.preprocess(x)[0])
                return torch.zeros(self.viclip_emb_dim)
            with torch.no_grad():
                B, n_frames, C, H, W = x.shape
                obs = torch.from_numpy(x.copy().reshape(B * n_frames, C, H, W)).to(self.viclip_model.device)
                processed_obs = self.viclip_model.preprocess_transf(obs / 255)
                reshaped_obs = processed_obs.reshape(B, n_frames, 3, processed_obs.shape[-2], processed_obs.shape[-1])
                video_embed = self.viclip_model.get_vid_features(reshaped_obs)
            return video_embed.detach()
        else:
            return torch.zeros(self.viclip_emb_dim)

    def step(self, action):
        ts, obs = self._env.step(action)
        self.buffer.append(self.hd_render(obs))
        obs['clip_video'] = self.clip_process(self.get_sequence()).cpu().numpy()
        return ts, obs

    def reset(self):
        # important: reset the frame buffer
        self.buffer = deque(maxlen=self.n_frames)
        ts, obs = self._env.reset()
        self.buffer.append(self.hd_render(obs))
        obs['clip_video'] = self.clip_process(self.get_sequence()).cpu().numpy()
        return ts, obs

    def __getattr__(self, name):
        if name == 'obs_space':
            space = self._env.obs_space
            space['clip_video'] = gym.spaces.Box(-np.inf, np.inf, (self.viclip_emb_dim,), dtype=np.float32)
            return space
        return getattr(self._env, name)
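
# Note on clip_process: until the frame buffer holds n_frames frames, and
# whenever self.accumulate is set, it returns a zero vector of size
# viclip_emb_dim; the actual video embeddings are computed in batches later
# via process_accumulate or process_episode.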

class TimeLimit:
    def __init__(self, env, duration):
        self._env = env
        self._duration = duration
        self._step = None

    def __getattr__(self, name):
        if name.startswith('__'):
            raise AttributeError(name)
        return getattr(self._env, name)

    def step(self, action):
        assert self._step is not None, 'Must reset environment.'
        ts, obs = self._env.step(action)
        self._step += 1
        if self._duration and self._step >= self._duration:
            ts = dm_env.TimeStep(dm_env.StepType.LAST, ts.reward, ts.discount, ts.observation)
            obs['is_last'] = True
            self._step = None
        return ts, obs

    def reset(self):
        self._step = 0
        return self._env.reset()

    def reset_with_task_id(self, task_id):
        self._step = 0
        return self._env.reset_with_task_id(task_id)

class ClipActionWrapper:
    def __init__(self, env, low=-1.0, high=1.0):
        self._env = env
        self._low = low
        self._high = high

    def __getattr__(self, name):
        if name.startswith('__'):
            raise AttributeError(name)
        return getattr(self._env, name)

    def step(self, action):
        clipped_action = np.clip(action, self._low, self._high)
        return self._env.step(clipped_action)

    def reset(self):
        return self._env.reset()

    def reset_with_task_id(self, task_id):
        return self._env.reset_with_task_id(task_id)

class NormalizeAction:
    def __init__(self, env, key='action'):
        self._env = env
        self._key = key
        space = env.act_space[key]
        self._mask = np.isfinite(space.low) & np.isfinite(space.high)
        self._low = np.where(self._mask, space.low, -1)
        self._high = np.where(self._mask, space.high, 1)

    def __getattr__(self, name):
        if name.startswith('__'):
            raise AttributeError(name)
        return getattr(self._env, name)

    @property
    def act_space(self):
        low = np.where(self._mask, -np.ones_like(self._low), self._low)
        high = np.where(self._mask, np.ones_like(self._low), self._high)
        space = gym.spaces.Box(low, high, dtype=np.float32)
        return {**self._env.act_space, self._key: space}

    def step(self, action):
        orig = (action[self._key] + 1) / 2 * (self._high - self._low) + self._low
        orig = np.where(self._mask, orig, action[self._key])
        return self._env.step({**action, self._key: orig})
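
# Example of the rescaling in NormalizeAction.step: for a bounded dimension
# with low=0.0 and high=4.0, a normalized action of 0.5 maps back to
# (0.5 + 1) / 2 * (4.0 - 0.0) + 0.0 = 3.0; unbounded dimensions are passed
# through unchanged via the mask.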

def _make_jaco(obs_type, domain, task, action_repeat, seed, img_size):
    import envs.custom_dmc_tasks as cdmc
    env = cdmc.make_jaco(task, obs_type, seed, img_size)
    env = ActionDTypeWrapper(env, np.float32)
    env = ActionRepeatWrapper(env, action_repeat)
    env = FlattenJacoObservationWrapper(env)
    env._size = (img_size, img_size)
    return env

def _make_dmc(obs_type, domain, task, action_repeat, seed, img_size):
    visualize_reward = False
    from dm_control import manipulation, suite
    import envs.custom_dmc_tasks as cdmc
    if (domain, task) in suite.ALL_TASKS:
        env = suite.load(domain,
                         task,
                         task_kwargs=dict(random=seed),
                         environment_kwargs=dict(flat_observation=True),
                         visualize_reward=visualize_reward)
    else:
        env = cdmc.make(domain,
                        task,
                        task_kwargs=dict(random=seed),
                        environment_kwargs=dict(flat_observation=True),
                        visualize_reward=visualize_reward)
    env = ActionDTypeWrapper(env, np.float32)
    env = ActionRepeatWrapper(env, action_repeat)

    # zoom in camera for quadruped
    camera_id = dict(locom_rodent=1, quadruped=2).get(domain, 0)
    if obs_type == 'pixels':
        from dm_control.suite.wrappers import pixels
        render_kwargs = dict(height=img_size, width=img_size, camera_id=camera_id)
        env = pixels.Wrapper(env,
                             pixels_only=True,
                             render_kwargs=render_kwargs)
    env._size = (img_size, img_size)
    env._camera = camera_id
    return env

def make(name, obs_type, action_repeat, seed, img_size=64, viclip_encode=False, clip_hd_rendering=False, device='cuda'):
    assert obs_type in ['states', 'pixels']
    domain, task = name.split('_', 1)

    if domain == 'kitchen':
        env = TimeLimit(KitchenWrapper(task, seed=seed, action_repeat=action_repeat, size=(img_size, img_size)), 280 // action_repeat)
    else:
        os.environ['PYOPENGL_PLATFORM'] = 'egl'
        os.environ['MUJOCO_GL'] = 'egl'
        domain = dict(cup='ball_in_cup', point='point_mass').get(domain, domain)

        make_fn = _make_jaco if domain == 'jaco' else _make_dmc
        env = make_fn(obs_type, domain, task, action_repeat, seed, img_size)

        if obs_type == 'pixels':
            env = FramesWrapper(env)
        else:
            env = ObservationDTypeWrapper(env, np.float32)

        from dm_control.suite.wrappers import action_scale
        env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0)
        env = ExtendedTimeStepWrapper(env)
        env = DMC(env)

    env._domain_name = domain

    if isinstance(env.act_space['action'], gym.spaces.Box):
        env = ClipActionWrapper(env)

    if viclip_encode:
        env = ViClipWrapper(env, hd_rendering=clip_hd_rendering, device=device)
    return env
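
# A minimal usage sketch (not part of the original module): builds a
# pixel-based walker_walk environment and runs a few random steps. It assumes
# dm_control and the envs/ package are importable and that 'walker_walk'
# follows the '<domain>_<task>' naming convention used by make().
if __name__ == '__main__':
    env = make('walker_walk', obs_type='pixels', action_repeat=2, seed=0)
    ts, obs = env.reset()
    # with the default num_frames=1, observations are (3, img_size, img_size)
    assert obs['is_first'] and obs['observation'].shape[0] == 3
    for _ in range(5):
        action = env.act_space['action'].sample()
        ts, obs = env.step(action)
        print(obs['reward'], obs['is_last'])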