# zoo/classic_control/pendulum/envs/pendulum_lightzero_env.py
import copy
from datetime import datetime
from typing import Union, Dict
import gymnasium as gym
import numpy as np
from ding.envs import BaseEnvTimestep
from ding.envs.common.common_function import affine_transform
from ding.torch_utils import to_ndarray
from ding.utils import ENV_REGISTRY
from easydict import EasyDict
from zoo.classic_control.cartpole.envs.cartpole_lightzero_env import CartPoleEnv
@ENV_REGISTRY.register('pendulum_lightzero')
class PendulumEnv(CartPoleEnv):
"""
LightZero version of the classic Pendulum environment. This class includes methods for resetting, closing, and
stepping through the environment, as well as seeding for reproducibility, saving replay videos, and generating random
actions. It also includes properties for accessing the observation space, action space, and reward space of the
environment.
"""
@classmethod
def default_config(cls: type) -> EasyDict:
cfg = EasyDict(copy.deepcopy(cls.config))
cfg.cfg_type = cls.__name__ + 'Dict'
return cfg
config = dict(
# (bool) Whether to use continuous action space
continuous=True,
        # (str or None) The path to save the replay video. If None, the replay will not be saved.
# Only effective when env_manager.type is 'base'.
replay_path=None,
# (bool) Whether to scale action into [-2, 2]
act_scale=True,
)
def __init__(self, cfg: dict) -> None:
"""
Initialize the environment with a configuration dictionary. Sets up spaces for observations, actions, and rewards.
"""
self._cfg = cfg
self._act_scale = cfg.act_scale
        try:
            self._env = gym.make('Pendulum-v1', render_mode="rgb_array")
        except gym.error.Error:
            # fall back to the legacy env id if 'Pendulum-v1' is not registered
            self._env = gym.make('Pendulum-v0', render_mode="rgb_array")
self._init_flag = False
self._replay_path = cfg.replay_path
self._continuous = cfg.get("continuous", True)
self._observation_space = gym.spaces.Box(
low=np.array([-1.0, -1.0, -8.0]), high=np.array([1.0, 1.0, 8.0]), shape=(3,), dtype=np.float32
)
if self._continuous:
self._action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1,), dtype=np.float32)
else:
self.discrete_action_num = 11
self._action_space = gym.spaces.Discrete(self.discrete_action_num)
self._action_space.seed(0) # default seed
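        # The reward is the negative pendulum cost -(theta^2 + 0.1 * theta_dot^2 + 0.001 * torque^2),
        # so the lower bound below plugs in the extremes theta = pi, theta_dot = 8 and torque = 2,
        # while the upper bound is 0 (the upright, motionless pendulum).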
self._reward_space = gym.spaces.Box(
low=-1 * (3.14 * 3.14 + 0.1 * 8 * 8 + 0.001 * 2 * 2), high=0.0, shape=(1,), dtype=np.float32
)
def reset(self) -> Dict[str, np.ndarray]:
"""
Reset the environment. If it hasn't been initialized yet, this method also handles that. It also handles seeding
if necessary. Returns the first observation.
"""
if not self._init_flag:
            try:
                self._env = gym.make('Pendulum-v1', render_mode="rgb_array")
            except gym.error.Error:
                # fall back to the legacy env id if 'Pendulum-v1' is not registered
                self._env = gym.make('Pendulum-v0', render_mode="rgb_array")
if self._replay_path is not None:
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
video_name = f'{self._env.spec.id}-video-{timestamp}'
self._env = gym.wrappers.RecordVideo(
self._env,
video_folder=self._replay_path,
episode_trigger=lambda episode_id: True,
name_prefix=video_name
)
self._init_flag = True
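        # Seeding: when a dynamic seed is enabled, a random offset is added to the base seed on every reset,
        # so each episode runs with a different but reproducible seed.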
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
np_seed = 100 * np.random.randint(1, 1000)
self._seed = self._seed + np_seed
self._action_space.seed(self._seed)
obs, _ = self._env.reset(seed=self._seed)
elif hasattr(self, '_seed'):
self._action_space.seed(self._seed)
obs, _ = self._env.reset(seed=self._seed)
else:
obs, _ = self._env.reset()
obs = to_ndarray(obs).astype(np.float32)
self._eval_episode_return = 0.
if not self._continuous:
action_mask = np.ones(self.discrete_action_num, 'int8')
else:
action_mask = None
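        # 'to_play' is fixed to -1 because Pendulum is a single-agent environment.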
obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
return obs
def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
"""
Overview:
Step the environment forward with the provided action. This method returns the next state of the environment
(observation, reward, done flag, and info dictionary) encapsulated in a BaseEnvTimestep object.
Arguments:
- action (:obj:`Union[int, np.ndarray]`): The action to be performed in the environment.
Returns:
- timestep (:obj:`BaseEnvTimestep`): An object containing the new observation, reward, done flag,
and info dictionary.
.. note::
- If the environment requires discrete actions, they are converted to float actions in the range [-1, 1].
- If action scaling is enabled, continuous actions are scaled into the range [-2, 2].
- For each step, the cumulative reward (`_eval_episode_return`) is updated.
- If the episode ends (done is True), the total reward for the episode is stored in the info dictionary
under the key 'eval_episode_return'.
            - If the environment requires discrete actions, an action mask is created; otherwise it is None.
- Observations are returned in a dictionary format containing 'observation', 'action_mask', and 'to_play'.
"""
if isinstance(action, int):
action = np.array(action)
        # if the action space is discrete, map the integer action to a float action in [-1, 1]
if not self._continuous:
action = (action / (self.discrete_action_num - 1)) * 2 - 1
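            # e.g. with discrete_action_num = 11: action 0 -> -1.0, action 5 -> 0.0, action 10 -> 1.0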
        # if action scaling is enabled, map the continuous action from [-1, 1] into the env's action range [-2, 2]
if self._act_scale:
action = affine_transform(action, min_val=self._env.action_space.low, max_val=self._env.action_space.high)
obs, rew, terminated, truncated, info = self._env.step(action)
done = terminated or truncated
self._eval_episode_return += rew
obs = to_ndarray(obs).astype(np.float32)
        # wrap the scalar reward into an array of shape (1,)
rew = to_ndarray([rew]).astype(np.float32)
if done:
info['eval_episode_return'] = self._eval_episode_return
if not self._continuous:
action_mask = np.ones(self.discrete_action_num, 'int8')
else:
action_mask = None
obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
return BaseEnvTimestep(obs, rew, done, info)
def random_action(self) -> np.ndarray:
"""
Generate a random action using the action space's sample method. Returns a numpy array containing the action.
"""
if self._continuous:
random_action = self.action_space.sample().astype(np.float32)
else:
random_action = self.action_space.sample()
random_action = to_ndarray([random_action], dtype=np.int64)
return random_action
def __repr__(self) -> str:
"""
String representation of the environment.
"""
return "LightZero Pendulum Env({})".format(self._cfg.env_id)