# swarms / redux_wrap.py
# YvesP's picture
# initial load
# a162e39
import gym
from gym import spaces
import numpy as np
from settings import Settings
class ReduxWrapper(gym.Wrapper):
    """Present a wrapped environment with fewer blue and red entries.

    The wrapper advertises observation and action spaces sized for
    ``Settings.blues - minus_blue`` blue rows and
    ``Settings.reds - minus_red`` red rows.  Actions coming in are padded
    with zero rows before being forwarded to the wrapped environment, and
    the trailing rows/columns of every observation are dropped before it
    is returned.

    :param env: (gym.Env) Gym environment that will be wrapped
    :param minus_blue: (int) number of trailing blue entries to hide
    :param minus_red: (int) number of trailing red entries to hide
    :raises ValueError: if a count is negative or larger than the
        corresponding ``Settings`` count
    """

    def __init__(self, env, minus_blue=0, minus_red=0):
        nb_blues, nb_reds = Settings.blues, Settings.reds

        # A negative count would stay truthy and turn every ``[:-count]``
        # slice in post_obs into ``[:k]`` (keep the first k rows), silently
        # returning wrongly shaped observations; a count above the total
        # would give a negative space shape.  Reject both up front.
        if not 0 <= minus_blue <= nb_blues:
            raise ValueError(
                "minus_blue must be between 0 and {}, got {}".format(
                    nb_blues, minus_blue))
        if not 0 <= minus_red <= nb_reds:
            raise ValueError(
                "minus_red must be between 0 and {}, got {}".format(
                    nb_reds, minus_red))

        # Sizes seen by the agent (action/observation spaces are reduced).
        self.nb_blues = nb_blues - minus_blue
        self.nb_reds = nb_reds - minus_red
        # Number of hidden trailing rows on each side.
        self.blue_deads = minus_blue
        self.red_deads = minus_red

        # NOTE: the spaces are replaced on the wrapped env itself, before
        # gym.Wrapper.__init__ runs, so the passed-in env is mutated.
        env.observation_space = spaces.Tuple((
            spaces.Box(low=0, high=1, shape=(self.nb_blues, 6), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_reds, 6), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_blues, self.nb_reds), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_reds, self.nb_blues), dtype=np.float32)))
        env.action_space = spaces.Tuple((
            spaces.Box(low=0, high=1, shape=(self.nb_blues, 3), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_reds, 3), dtype=np.float32)))

        super(ReduxWrapper, self).__init__(env)

    def reset(self):
        """
        Reset the wrapped environment and return the reduced observation.
        """
        obs = self.env.reset()
        return self.post_obs(obs)

    def step(self, action):
        """
        Pad the reduced action with zero rows and step the wrapped env.

        :param action: pair ``(blue_action, red_action)``
        :return: ``(obs, reward, done, info)`` where ``obs`` has been
            reduced by :meth:`post_obs`; the other items are passed
            through unchanged
        """
        blue_action, red_action = action
        # Action needs expansion: append one all-zero row (3 columns) per
        # hidden entry so the wrapped env receives its full-size arrays.
        if self.blue_deads:
            blue_action = np.vstack((blue_action, np.zeros((self.blue_deads, 3))))
        if self.red_deads:
            red_action = np.vstack((red_action, np.zeros((self.red_deads, 3))))
        action = blue_action, red_action
        obs, reward, done, info = self.env.step(action)
        obs = self.post_obs(obs)
        return obs, reward, done, info

    def post_obs(self, obs):
        """
        Drop the hidden trailing rows/columns from a full observation.

        :param obs: 4-tuple ``(blue_obs, red_obs, blues_fire, reds_fire)``
            of sliceable arrays; ``blues_fire`` is indexed (blue, red) and
            ``reds_fire`` is indexed (red, blue)
        :return: the same 4-tuple with the last ``blue_deads`` blue
            entries and the last ``red_deads`` red entries removed
        """
        blue_obs, red_obs, blues_fire, reds_fire = obs
        # The guards matter: ``x[:-0]`` would be empty, not the full array.
        if self.blue_deads:
            blue_obs = blue_obs[:-self.blue_deads]
            blues_fire = blues_fire[:-self.blue_deads]
            reds_fire = reds_fire[:, :-self.blue_deads]
        if self.red_deads:
            red_obs = red_obs[:-self.red_deads]
            reds_fire = reds_fire[:-self.red_deads]
            blues_fire = blues_fire[:, :-self.red_deads]
        return blue_obs, red_obs, blues_fire, reds_fire