# vpg-CartPole-v1 / wrappers/atari_wrappers.py
# sgoodfriend's picture
# VPG playing CartPole-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/e8bc541d8b5e67bb4d3f2075282463fb61f5f2c6
# 9dc837c
# raw
# history blame
# 2.95 kB
import gym
import numpy as np
from typing import Any, Dict, Tuple, Union
# Observation type used in the wrappers' `reset()` return value and as the
# first element of the tuple returned by `step()`.
ObsType = Union[np.ndarray, dict]
# Action type accepted by the wrappers' `step()` methods.
ActType = Union[int, float, np.ndarray, dict]
class EpisodicLifeEnv(gym.Wrapper):
    """Report the loss of a life as the end of an episode while training.

    The life count is read from ``env.unwrapped.ale.lives()`` after every
    step. When ``training`` is true and the count drops to a value that is
    still positive, ``step`` returns ``done=True`` even though the underlying
    game is not over. The following ``reset`` then advances the same game
    with a single no-op action instead of resetting the wrapped environment.

    Args:
        env: Environment to wrap; ``env.unwrapped`` must expose ``ale.lives()``.
        training: Enables the life-loss-as-done behaviour. When false, ``done``
            is passed through unchanged and ``reset`` always resets ``env``.
        noop_act: Action used by ``reset`` to continue after a lost life.
    """

    def __init__(self, env: gym.Env, training: bool = True, noop_act: int = 0) -> None:
        super().__init__(env)
        self.training = training
        self.noop_act = noop_act
        # True when the most recent step changed the life count without the
        # wrapped environment itself reporting done.
        self.life_done_continue = False
        # Life count observed after the most recent step/reset.
        self.lives = 0

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
        """Step the wrapped env, flagging a lost life as done when training."""
        obs, rew, done, info = self.env.step(action)
        new_lives = self.env.unwrapped.ale.lives()
        # Must be computed from the wrapped env's `done`, before it is
        # overridden below, so that `reset` can tell a lost life from game over.
        self.life_done_continue = new_lives != self.lives and not done
        # Only if training should life-end be marked as done
        if self.training and 0 < new_lives < self.lives:
            done = True
        self.lives = new_lives
        return obs, rew, done, info

    def reset(self, **kwargs) -> ObsType:
        """Continue after a lost life, or reset the wrapped env otherwise.

        Returns:
            The observation from the no-op step or from the wrapped reset.
        """
        # If life_done_continue (but not game over), then a reset should just allow the
        # game to progress to the next life.
        if self.training and self.life_done_continue:
            obs, _, done, _ = self.env.step(self.noop_act)
            # The no-op step can itself end the game; in that case the wrapped
            # env needs a real reset, otherwise a terminal observation would be
            # returned and the next step would hit a finished episode.
            if done:
                obs = self.env.reset(**kwargs)
        else:
            obs = self.env.reset(**kwargs)
        self.lives = self.env.unwrapped.ale.lives()
        return obs
class FireOnLifeStarttEnv(gym.Wrapper):
    """Press FIRE on reset and again on the step following each lost life.

    The life count is read from ``env.unwrapped.ale.lives()``. Whenever a step
    lowers it to a value that is still positive, and the wrapped env did not
    report done, the next action passed to ``step`` is replaced by
    ``fire_act``.

    Args:
        env: Environment to wrap; ``env.unwrapped`` must provide
            ``get_action_meanings()`` and ``ale.lives()``.
        fire_act: Index of the action whose meaning is ``"FIRE"``.
    """

    def __init__(self, env: gym.Env, fire_act: int = 1) -> None:
        super().__init__(env)
        self.fire_act = fire_act
        action_meanings = env.unwrapped.get_action_meanings()
        assert action_meanings[fire_act] == "FIRE"
        # `reset` also issues action index 2, so at least three actions must exist.
        assert len(action_meanings) >= 3
        # Life count observed after the most recent step/reset.
        self.lives = 0
        self.fire_on_next_step = True

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
        """Step the wrapped env, substituting FIRE if a life was just lost."""
        if self.fire_on_next_step:
            action = self.fire_act
            self.fire_on_next_step = False
        obs, rew, done, info = self.env.step(action)
        new_lives = self.env.unwrapped.ale.lives()
        if 0 < new_lives < self.lives and not done:
            self.fire_on_next_step = True
        self.lives = new_lives
        return obs, rew, done, info

    def reset(self, **kwargs) -> ObsType:
        """Reset the wrapped env, then step with ``fire_act`` and action 2.

        If either of those steps reports done, the wrapped env is reset again.

        Returns:
            The observation matching the wrapped env's state on return.
        """
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(self.fire_act)
        if done:
            self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(2)
        if done:
            # Return the fresh observation rather than the terminal one from
            # the step that just ended the episode.
            obs = self.env.reset(**kwargs)
        # Start life tracking from the current count so that a life lost on the
        # very first step after reset is detected in `step`.
        self.lives = self.env.unwrapped.ale.lives()
        self.fire_on_next_step = False
        return obs
class ClipRewardEnv(gym.Wrapper):
    """Replace each reward by its sign while training.

    With ``training`` enabled, the reward returned by the wrapped env is stored
    in ``info["unclipped_reward"]`` and ``np.sign`` of it is returned in its
    place. With ``training`` disabled, the step result is passed through
    untouched.
    """

    def __init__(self, env: gym.Env, training: bool = True) -> None:
        super().__init__(env)
        self.training = training

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
        """Step the wrapped env and clip the reward to its sign if training."""
        observation, reward, finished, extras = self.env.step(action)
        if not self.training:
            return observation, reward, finished, extras
        # Keep the raw reward available to callers before it is clipped.
        extras["unclipped_reward"] = reward
        return observation, np.sign(reward), finished, extras