# Source: gomoku / LightZero / zoo / atari / envs / atari_wrappers.py
# Author: zjowowen, commit 079c32c ("init space"), 10.7 kB
# Adapted from openai baselines: https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
from datetime import datetime
from typing import Optional
import cv2
import gymnasium
import gym
import numpy as np
from ding.envs import NoopResetWrapper, MaxAndSkipWrapper, EpisodicLifeWrapper, FireResetWrapper, WarpFrameWrapper, \
ScaledFloatFrameWrapper, \
ClipRewardWrapper, FrameStackWrapper
from ding.utils.compression_helper import jpeg_data_compressor
from easydict import EasyDict
from gymnasium.wrappers import RecordVideo
# only for reference now
def wrap_deepmind(env_id, episode_life=True, clip_rewards=True, frame_stack=4, scale=True, warp_frame=True):
    """Build a DeepMind-style Atari environment from a ``NoFrameskip`` env id.

    The observation is documented as channel-first: (c, h, w) instead of (h, w, c).

    Wrappers are applied in this fixed order: noop-reset (up to 30 noops), max-and-skip (skip=4),
    optional episodic-life, fire-reset (only if the game has a FIRE action), then the optional
    warp-frame, float-scaling, reward-clipping and frame-stacking layers.

    :param str env_id: the atari environment id; must contain ``'NoFrameskip'``.
    :param bool episode_life: if True, add ``EpisodicLifeWrapper``.
    :param bool clip_rewards: if True, add ``ClipRewardWrapper``.
    :param int frame_stack: number of frames to stack; a falsy value disables frame stacking.
    :param bool scale: if True, add ``ScaledFloatFrameWrapper``.
    :param bool warp_frame: if True, add ``WarpFrameWrapper`` (grayscale + resize).
    :return: the wrapped atari environment.
    """
    assert 'NoFrameskip' in env_id
    wrapped = NoopResetWrapper(gym.make(env_id), noop_max=30)
    wrapped = MaxAndSkipWrapper(wrapped, skip=4)
    if episode_life:
        wrapped = EpisodicLifeWrapper(wrapped)
    # Add FireResetWrapper only when the game's action set contains FIRE.
    if 'FIRE' in wrapped.unwrapped.get_action_meanings():
        wrapped = FireResetWrapper(wrapped)
    # Optional single-argument layers, applied in this exact order.
    optional_layers = (
        (warp_frame, WarpFrameWrapper),
        (scale, ScaledFloatFrameWrapper),
        (clip_rewards, ClipRewardWrapper),
    )
    for enabled, layer_cls in optional_layers:
        if enabled:
            wrapped = layer_cls(wrapped)
    if frame_stack:
        wrapped = FrameStackWrapper(wrapped, frame_stack)
    return wrapped
# only for reference now
def wrap_deepmind_mr(env_id, episode_life=True, clip_rewards=True, frame_stack=4, scale=True, warp_frame=True):
    """Build a Montezuma's Revenge environment with the DeepMind-style Atari wrapper stack.

    The wrapper stack is the same as in ``wrap_deepmind``; only the env-id check differs
    (the id must contain ``'MontezumaRevenge'``). The observation is documented as
    channel-first: (c, h, w) instead of (h, w, c).

    :param str env_id: the atari environment id; must contain ``'MontezumaRevenge'``.
    :param bool episode_life: if True, add ``EpisodicLifeWrapper``.
    :param bool clip_rewards: if True, add ``ClipRewardWrapper``.
    :param int frame_stack: number of frames to stack; a falsy value disables frame stacking.
    :param bool scale: if True, add ``ScaledFloatFrameWrapper``.
    :param bool warp_frame: if True, add ``WarpFrameWrapper`` (grayscale + resize).
    :return: the wrapped atari environment.
    """
    assert 'MontezumaRevenge' in env_id
    # Always-on base layers: random noop resets, then max-pool + frame skip of 4.
    result = MaxAndSkipWrapper(NoopResetWrapper(gym.make(env_id), noop_max=30), skip=4)
    if episode_life:
        result = EpisodicLifeWrapper(result)
    # Add FireResetWrapper only when the game's action set contains FIRE.
    needs_fire = 'FIRE' in result.unwrapped.get_action_meanings()
    if needs_fire:
        result = FireResetWrapper(result)
    result = WarpFrameWrapper(result) if warp_frame else result
    result = ScaledFloatFrameWrapper(result) if scale else result
    result = ClipRewardWrapper(result) if clip_rewards else result
    if not frame_stack:
        return result
    return FrameStackWrapper(result, frame_stack)
def wrap_lightzero(config: EasyDict, episode_life: bool, clip_rewards: bool) -> gym.Env:
    """
    Overview:
        Configure environment for MuZero-style Atari. The observation is
        channel-first: (c, h, w) instead of (h, w, c).
    Arguments:
        - config (:obj:`EasyDict`): Environment configuration. Fields read here: ``env_name``,
          ``render_mode_human``, ``save_replay``, ``replay_path``, ``frame_skip``, ``max_episode_steps``,
          ``warp_frame``, ``obs_shape`` (channel-first ``(c, h, w)``), ``gray_scale``, ``scale``,
          ``transform2string`` and ``game_wrapper``.
        - episode_life (:obj:`bool`): If True, wrap the env with ``EpisodicLifeWrapper``.
        - clip_rewards (:obj:`bool`): If True, wrap the env with ``ClipRewardWrapper``.
    Return:
        - env (:obj:`gym.Env`): The wrapped Atari environment with the given configurations.
    """
    render_mode = 'human' if config.render_mode_human else 'rgb_array'
    env = gymnasium.make(config.env_name, render_mode=render_mode)
    assert 'NoFrameskip' in env.spec.id
    if config.save_replay:
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        video_name = f'{env.spec.id}-video-{timestamp}'
        # Record every episode; RecordVideo operates on the gymnasium env, so it is applied before
        # the gymnasium -> gym adapter.
        env = RecordVideo(
            env,
            video_folder=config.replay_path,
            episode_trigger=lambda episode_id: True,
            name_prefix=video_name
        )
    env = GymnasiumToGymWrapper(env)
    env = NoopResetWrapper(env, noop_max=30)
    env = MaxAndSkipWrapper(env, skip=config.frame_skip)
    if episode_life:
        env = EpisodicLifeWrapper(env)
    env = TimeLimit(env, max_episode_steps=config.max_episode_steps)
    if config.warp_frame:
        # We must set WarpFrame before ScaledFloatFrameWrapper: WarpFrame asserts a uint8 input space.
        # ``config.obs_shape`` is channel-first (c, h, w), so index 1 is the height and index 2 the
        # width. (These were previously swapped, which only went unnoticed for square frames.)
        env = WarpFrame(env, width=config.obs_shape[2], height=config.obs_shape[1], grayscale=config.gray_scale)
    if config.scale:
        env = ScaledFloatFrameWrapper(env)
    if clip_rewards:
        env = ClipRewardWrapper(env)
    env = JpegWrapper(env, transform2string=config.transform2string)
    if config.game_wrapper:
        env = GameWrapper(env)
    return env
class TimeLimit(gym.Wrapper):
    """
    Overview:
        A wrapper that limits the maximum number of steps in an episode. When the limit is reached,
        ``done`` is forced to True and ``info['TimeLimit.truncated']`` is set to True.
    """

    def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None):
        """
        Arguments:
            - env (:obj:`gym.Env`): The environment to wrap.
            - max_episode_steps (:obj:`Optional[int]`): Maximum number of steps per episode. If None, no limit is applied.
        """
        super(TimeLimit, self).__init__(env)
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0

    def step(self, ac):
        """
        Overview:
            Step the wrapped env (expects the 4-tuple ``(obs, reward, done, info)`` API) and count the step.
        Returns:
            - The same 4-tuple, with ``done`` forced to True once the step limit is reached.
        """
        observation, reward, done, info = self.env.step(ac)
        self._elapsed_steps += 1
        # None means "no limit"; comparing an int against None would raise TypeError.
        if self._max_episode_steps is not None and self._elapsed_steps >= self._max_episode_steps:
            done = True
            info['TimeLimit.truncated'] = True
        return observation, reward, done, info

    def reset(self, **kwargs):
        """
        Overview:
            Reset the step counter and the wrapped env, forwarding any keyword arguments.
        """
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)
class WarpFrame(gym.ObservationWrapper):
    """
    Overview:
        A wrapper that resizes frames to ``height x width`` (84x84 by default, as done in the Nature
        paper and later work), optionally converting them to grayscale. The resulting observation space
        is a uint8 Box of shape ``(height, width, 1)`` in grayscale mode and ``(height, width, 3)`` otherwise.
    """

    def __init__(self, env: gym.Env, width: int = 84, height: int = 84, grayscale: bool = True,
                 dict_space_key: Optional[str] = None):
        """
        Arguments:
            - env (:obj:`gym.Env`): The environment to wrap.
            - width (:obj:`int`): The width to which the frames are resized.
            - height (:obj:`int`): The height to which the frames are resized.
            - grayscale (:obj:`bool`): If True, convert frames to grayscale.
            - dict_space_key (:obj:`Optional[str]`): If specified, the observation is a dict and only
              the entry under this key is warped.
        """
        super().__init__(env)
        self._width = width
        self._height = height
        self._grayscale = grayscale
        self._key = dict_space_key
        num_colors = 1 if self._grayscale else 3
        new_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(self._height, self._width, num_colors),
            dtype=np.uint8,
        )
        if self._key is None:
            original_space = self.observation_space
            self.observation_space = new_space
        else:
            import copy
            # Shallow-copy the Dict space and its ``spaces`` mapping before replacing one entry;
            # assigning into ``self.observation_space.spaces`` directly would mutate the wrapped
            # env's own observation space in place.
            dict_space = copy.copy(self.observation_space)
            dict_space.spaces = copy.copy(dict_space.spaces)
            original_space = dict_space.spaces[self._key]
            dict_space.spaces[self._key] = new_space
            self.observation_space = dict_space
        # Only uint8 image-like (rank-3) inputs are supported.
        assert original_space.dtype == np.uint8 and len(original_space.shape) == 3

    def observation(self, obs):
        """
        Overview:
            Warp one observation (or the ``dict_space_key`` entry of a dict observation).
        Returns:
            - The warped frame, or a shallow copy of the dict with the warped entry replaced.
        """
        frame = obs if self._key is None else obs[self._key]
        if self._grayscale:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # cv2.resize takes dsize as (width, height); the result has shape (height, width[, c]).
        frame = cv2.resize(frame, (self._width, self._height), interpolation=cv2.INTER_AREA)
        if self._grayscale:
            # cvtColor drops the channel axis; restore it to match the declared (h, w, 1) space.
            frame = np.expand_dims(frame, -1)
        if self._key is None:
            return frame
        obs = obs.copy()
        obs[self._key] = frame
        return obs
class JpegWrapper(gym.Wrapper):
    """
    Overview:
        A wrapper that optionally encodes every observation with ``jpeg_data_compressor``
        (converting it into a compact string-like representation) to save memory.
    """

    def __init__(self, env: gym.Env, transform2string: bool = True):
        """
        Arguments:
            - env (:obj:`gym.Env`): The environment to wrap.
            - transform2string (:obj:`bool`): If True, compress observations returned by ``reset`` and ``step``;
              otherwise pass them through unchanged.
        """
        super().__init__(env)
        self.transform2string = transform2string

    def _encode(self, obs):
        # Compress only when enabled; otherwise return the observation untouched.
        if not self.transform2string:
            return obs
        return jpeg_data_compressor(obs)

    def step(self, action):
        """Step the wrapped env (4-tuple API) and encode the returned observation."""
        obs, rew, is_done, step_info = self.env.step(action)
        return self._encode(obs), rew, is_done, step_info

    def reset(self, **kwargs):
        """Reset the wrapped env, forwarding kwargs, and encode the initial observation."""
        return self._encode(self.env.reset(**kwargs))
class GameWrapper(gym.Wrapper):
    """
    Overview:
        A wrapper to adapt the environment to the game interface by exposing ``legal_actions``.
    """

    def __init__(self, env: gym.Env):
        """
        Arguments:
            - env (:obj:`gym.Env`): The environment to wrap.
        """
        super().__init__(env)

    def legal_actions(self):
        """
        Overview:
            Return every action of the wrapped env's discrete action space; all actions are treated as legal.
        Returns:
            - actions (:obj:`List[int]`): ``[0, 1, ..., n - 1]`` where ``n`` is ``action_space.n``.
        """
        return list(range(self.env.action_space.n))
class GymnasiumToGymWrapper(gym.Wrapper):
    """
    Overview:
        A wrapper class that adapts a Gymnasium environment to the Gym interface: ``reset`` returns only
        the observation, and ``seed`` follows classic Gym semantics (the seed is applied once, on the
        next reset, after which resets continue the environment's RNG stream).
    Interface:
        ``__init__``, ``reset``, ``seed``
    Properties:
        - _seed (:obj:`int` or None): The last seed value passed to ``seed``.
    """

    def __init__(self, env):
        """
        Overview:
            Initializes the GymnasiumToGymWrapper.
        Arguments:
            - env (:obj:`gymnasium.Env`): The Gymnasium environment to be wrapped.
        """
        assert isinstance(env, gymnasium.Env), type(env)
        super().__init__(env)
        self._seed = None
        # True while a seed set via ``seed`` has not yet been handed to ``env.reset``.
        self._seed_pending = False

    def seed(self, seed):
        """
        Overview:
            Sets the seed value for the environment. It is passed to the wrapped env on the next ``reset``
            only, so consecutive episodes do not all restart from the same RNG state. Call ``seed`` again
            to reseed.
        Arguments:
            - seed (:obj:`int`): The seed value to use for random number generation.
        """
        self._seed = seed
        self._seed_pending = seed is not None

    def reset(self, **kwargs):
        """
        Overview:
            Resets the environment and returns the initial observation (the gymnasium ``info`` is dropped).
            Keyword arguments are forwarded to the gymnasium ``reset``; an explicit ``seed`` kwarg takes
            precedence over a seed set via ``seed``.
        Returns:
            - observation (:obj:`Any`): The initial observation of the environment.
        """
        if self._seed_pending and 'seed' not in kwargs:
            kwargs['seed'] = self._seed
        # Either way the env has now been (re)seeded or deliberately left alone; don't reseed again.
        self._seed_pending = False
        obs, _ = self.env.reset(**kwargs)
        return obs