# Text-Gym-Agents / envs / atari / represented_atari_game.py
# hyyh28
# update atari_env
# 1aa4792
import gym
import ale_py
import numpy as np
from atariari.benchmark.wrapper import AtariARIWrapper
from typing import Optional, Union
class RepresentedAtariEnv(gym.Wrapper):
    """Atari environment whose observations are the AtariARI label values.

    Wraps ``gym.make(env_name)`` in ``AtariARIWrapper`` and replaces the
    observation returned by ``reset``/``step`` with a 1-D array built from
    the values of ``self.env.labels()`` (in the mapping's iteration order).

    Attributes:
        env_name: The id that was passed to ``gym.make``.
        observation: The most recent label mapping; ``None`` until the first
            call to ``reset`` or ``step`` on this wrapper.
        obs_label: Keys view of the most recent label mapping.
        info: Empty dict created at construction; not updated by this class.
    """

    def __init__(self, env_name, render_mode: Optional[str] = None):
        """Build the wrapped environment and derive the observation space.

        Args:
            env_name: Gym environment id to create.
            render_mode: Forwarded unchanged to ``gym.make``.
        """
        super().__init__(AtariARIWrapper(gym.make(env_name, render_mode=render_mode)))
        self.metadata = self.env.metadata
        self.env_name = env_name
        self.observation = None
        self.info = {}
        self.action_space = self.env.action_space

        # Reset once so the labels can be read; their count fixes the
        # length of the observation vector.
        self.env.reset()
        labels = self.env.labels()
        self.obs_label = labels.keys()
        # NOTE(review): the space is declared float32, but the arrays returned
        # by reset/step are not cast and keep whatever dtype NumPy infers from
        # the label values -- confirm whether a cast is wanted.
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(len(labels),),
            dtype=np.float32,
        )

    def _read_labels(self):
        """Cache the current labels and return their values as an array."""
        labels = self.env.labels()
        self.obs_label = labels.keys()
        self.observation = labels
        return np.array(list(labels.values()))

    def step(self, action):
        """Advance the wrapped env and return the label-based observation.

        Args:
            action: Action forwarded to the wrapped environment.

        Returns:
            ``(label_values, reward, terminated, truncated, info)``; the
            wrapped environment's own observation is discarded.
        """
        _, reward, env_done, env_truncated, info = self.env.step(action)
        return self._read_labels(), reward, env_done, env_truncated, info

    def reset(self, seed=0, options=None):
        """Reset the wrapped env and return the label-based observation.

        Args:
            seed: Seed forwarded to the wrapped ``reset``. The default stays
                ``0`` for backward compatibility, which means every default
                reset passes the same seed.
                NOTE(review): Gym recommends ``None`` as the default; pass
                ``seed=None`` explicitly if re-seeding is not wanted.
            options: Optional reset options, as in the Gym reset API. It is
                forwarded only when given, so the default call is identical
                to the previous ``reset(seed=seed)``.

        Returns:
            ``(label_values, info)``.
        """
        reset_kwargs = {"seed": seed}
        if options is not None:
            reset_kwargs["options"] = options
        _, info = self.env.reset(**reset_kwargs)
        return self._read_labels(), info

    def get_info(self):
        """Return the most recent label mapping (``None`` before any reset/step)."""
        return self.observation

    def render(self, render_mode=None):
        """Render via the wrapped env; ``render_mode`` is accepted but unused."""
        return self.env.render()
class RepresentedMsPacman(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``MsPacmanNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="MsPacmanNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedBowling(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``BowlingNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="BowlingNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedBoxing(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``BoxingNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="BoxingNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedBreakout(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``BreakoutNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="BreakoutNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedDemonAttack(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``DemonAttackNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="DemonAttackNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedFreeway(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``FreewayNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="FreewayNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedFrostbite(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``FrostbiteNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="FrostbiteNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedHero(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``HeroNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="HeroNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedMontezumaRevenge(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``MontezumaRevengeNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="MontezumaRevengeNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedPitfall(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``PitfallNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="PitfallNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedPong(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``PongNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="PongNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedPrivateEye(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``PrivateEyeNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="PrivateEyeNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedQbert(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``QbertNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="QbertNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedRiverraid(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``RiverraidNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="RiverraidNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedSeaquest(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``SeaquestNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="SeaquestNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedSpaceInvaders(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``SpaceInvadersNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="SpaceInvadersNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedTennis(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``TennisNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="TennisNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedVenture(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``VentureNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="VentureNoFrameskip-v4",
            render_mode=render_mode,
        )
class RepresentedVideoPinball(RepresentedAtariEnv):
    """``RepresentedAtariEnv`` preconfigured for ``VideoPinballNoFrameskip-v4``."""

    def __init__(self, render_mode: Optional[str] = None):
        # Only the Gym id is fixed here; rendering stays caller-controlled.
        super().__init__(
            env_name="VideoPinballNoFrameskip-v4",
            render_mode=render_mode,
        )
def env_factory(env_class):
    """Return a constructor callable bound to ``env_class``.

    Args:
        env_class: Any callable that accepts a ``render_mode`` keyword.

    Returns:
        A function taking an optional ``render_mode`` (default ``None``) that
        returns ``env_class(render_mode=render_mode)``.
    """

    def _make_env(render_mode=None):
        # Always forward render_mode by keyword, whatever way it arrived.
        instance = env_class(render_mode=render_mode)
        return instance

    return _make_env
def register_environments():
    """Register every ``Represented*`` Atari environment with Gym.

    Each (id, class) pair below is passed to ``gym.register`` with an entry
    point produced by ``env_factory``; registration happens in the listed
    order.
    """
    registrations = (
        ('RepresentedMsPacman-v0', RepresentedMsPacman),
        ('RepresentedBowling-v0', RepresentedBowling),
        ('RepresentedBoxing-v0', RepresentedBoxing),
        ('RepresentedBreakout-v0', RepresentedBreakout),
        ('RepresentedDemonAttack-v0', RepresentedDemonAttack),
        ('RepresentedFreeway-v0', RepresentedFreeway),
        ('RepresentedFrostbite-v0', RepresentedFrostbite),
        ('RepresentedHero-v0', RepresentedHero),
        ('RepresentedMontezumaRevenge-v0', RepresentedMontezumaRevenge),
        ('RepresentedPitfall-v0', RepresentedPitfall),
        ('RepresentedPong-v0', RepresentedPong),
        ('RepresentedPrivateEye-v0', RepresentedPrivateEye),
        ('RepresentedQbert-v0', RepresentedQbert),
        ('RepresentedRiverraid-v0', RepresentedRiverraid),
        ('RepresentedSeaquest-v0', RepresentedSeaquest),
        ('RepresentedSpaceInvaders-v0', RepresentedSpaceInvaders),
        ('RepresentedTennis-v0', RepresentedTennis),
        ('RepresentedVenture-v0', RepresentedVenture),
        ('RepresentedVideoPinball-v0', RepresentedVideoPinball),
    )
    for gym_id, wrapper_cls in registrations:
        entry_point = env_factory(wrapper_cls)
        gym.register(id=gym_id, entry_point=entry_point)
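# Registration is left disabled here; call register_environments() explicitly
# to register the ids listed in that function.
# register_environments()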
# env_classes = {
# 'RepresentedMsPacman-v0': RepresentedMsPacman,
# 'RepresentedBowling-v0': RepresentedBowling,
# 'RepresentedBoxing-v0': RepresentedBoxing,
# 'RepresentedBreakout-v0': RepresentedBreakout,
# 'RepresentedDemonAttack-v0': RepresentedDemonAttack,
# 'RepresentedFreeway-v0': RepresentedFreeway,
# 'RepresentedFrostbite-v0': RepresentedFrostbite,
# 'RepresentedHero-v0': RepresentedHero,
# 'RepresentedMontezumaRevenge-v0': RepresentedMontezumaRevenge,
# 'RepresentedPitfall-v0': RepresentedPitfall,
# 'RepresentedPong-v0': RepresentedPong,
# 'RepresentedPrivateEye-v0': RepresentedPrivateEye,
# 'RepresentedQbert-v0': RepresentedQbert,
# 'RepresentedRiverraid-v0': RepresentedRiverraid,
# 'RepresentedSeaquest-v0': RepresentedSeaquest,
# 'RepresentedSpaceInvaders-v0': RepresentedSpaceInvaders,
# 'RepresentedTennis-v0': RepresentedTennis,
# 'RepresentedVenture-v0': RepresentedVenture,
# 'RepresentedVideoPinball-v0': RepresentedVideoPinball
# }
#
# for env, env_class in env_classes.items():
# env_1 = env_class()
# env_name = env_1.env_name
# env_2 = gym.make(env_name)
# print(env_name, env_1.action_space == env_2.action_space, env_1.action_space)