# swarms / filter_wrap.py
# Author: YvesP -- commit a162e39 ("initial load")
import numpy as np
from gym import spaces, Wrapper
class FilterWrapper(Wrapper):
    """
    Show living drones first and hide dead ones from the agent.

    Observations returned by the wrapped environment are re-ordered so that
    the rows belonging to living drones come first and the rows of dead
    drones come last; the two boolean "dead" masks are replaced by the number
    of dead drones in each team. Actions are expected for living drones only
    (in that compacted order) and are expanded back to one action per drone
    before being forwarded to the wrapped environment.

    :param env: (gym.Env) Gym environment that will be wrapped; it must
        expose ``nb_blues`` and ``nb_reds``
    """

    def __init__(self, env):
        self.nb_blues, self.nb_reds = env.nb_blues, env.nb_reds

        # Until the first observation arrives, every drone is considered alive.
        self.blue_deads = np.full((self.nb_blues,), False)
        self.red_deads = np.full((self.nb_reds,), False)

        # The space is installed on the wrapped env before the base Wrapper is
        # initialised -- presumably so the Wrapper exposes it as its own.
        env.observation_space = spaces.Tuple((
            spaces.Box(low=0, high=1, shape=(self.nb_blues, 6), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_reds, 6), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_blues, self.nb_reds), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_reds, self.nb_blues), dtype=np.float32),
            # Dead counts range over 0..nb inclusive, hence nb + 1 values
            # (Discrete(1) would only ever allow the value 0).
            spaces.Discrete(self.nb_blues + 1),
            spaces.Discrete(self.nb_reds + 1),
        ))

        super(FilterWrapper, self).__init__(env)

    def reset(self):
        """
        Reset the wrapped environment.

        :return: (tuple) first observation, re-ordered by ``_sort_obs``
        """
        return self._sort_obs(self.env.reset())

    def step(self, action):
        """
        Expand the living drones' actions to all drones and step the env.

        :param action: (tuple) pair ``(blue_action, red_action)``; each item
            is indexable and holds one action per *living* drone of the team,
            in the order used by the last observation
        :return: (tuple, float, bool, dict) re-ordered observation, reward,
            is the episode over?, additional informations
        """
        blue_action, red_action = action
        full_action = (
            self._expand_actions(blue_action, self.blue_deads),
            self._expand_actions(red_action, self.red_deads),
        )
        obs, reward, done, info = self.env.step(full_action)
        return self._sort_obs(obs), reward, done, info

    @staticmethod
    def _expand_actions(alive_actions, deads):
        """
        Build one action per drone from the actions of the living drones.

        :param alive_actions: indexable holding the living drones' actions;
            entries beyond the number of living drones are ignored
        :param deads: boolean mask, True where the drone is dead
        :return: (list) one action per drone, in the wrapped env's order
        """
        expanded = []
        next_alive = 0
        for dead in np.asarray(deads, dtype=bool):
            if dead:
                # Dead drones get a fresh all-zero action of length 3.
                expanded.append(np.array([0, 0, 0]))
            else:
                expanded.append(alive_actions[next_alive])
                next_alive += 1
        return expanded

    def _sort_obs(self, obs):
        """
        Remember who is dead and move the living drones to the front.

        :param obs: (tuple) ``(blue_obs, red_obs, blues_fire, reds_fire,
            blue_deads, red_deads)`` as returned by the wrapped environment
        :return: (tuple) same layout, with living rows first in the two
            per-drone arrays, the fire matrices compacted by ``fire_sort``,
            and the two masks replaced by the number of dead drones
        """
        blue_obs, red_obs, blues_fire, reds_fire, blue_deads, red_deads = obs

        # Coerce to boolean arrays so that ``~`` below is a logical NOT even
        # if the env hands back 0/1 integers or plain lists.
        self.blue_deads = np.asarray(blue_deads, dtype=bool)
        self.red_deads = np.asarray(red_deads, dtype=bool)

        blue_obs = np.vstack((blue_obs[~self.blue_deads], blue_obs[self.blue_deads]))
        red_obs = np.vstack((red_obs[~self.red_deads], red_obs[self.red_deads]))

        blues_fire = self.fire_sort(self.blue_deads, self.red_deads, blues_fire)
        reds_fire = self.fire_sort(self.red_deads, self.blue_deads, reds_fire)

        return (
            blue_obs,
            red_obs,
            blues_fire,
            reds_fire,
            sum(self.blue_deads),
            sum(self.red_deads),
        )

    def fire_sort(self, dead_friends, dead_foes, friends_fire):
        """
        Compact a fire matrix onto its living rows and columns.

        Rows of dead friends and columns of dead foes are removed; what is
        left is written into the top-left corner of a zero matrix that has
        the shape and dtype of the input.

        :param dead_friends: boolean mask over the rows, True where dead
        :param dead_foes: boolean mask over the columns, True where dead
        :param friends_fire: (np.ndarray) matrix with one row per friend and
            one column per foe
        :return: (np.ndarray) zero-padded matrix, same shape as the input
        """
        alive_friends = ~np.asarray(dead_friends, dtype=bool)
        alive_foes = ~np.asarray(dead_foes, dtype=bool)

        padded = np.zeros_like(friends_fire)
        kept = np.compress(alive_friends, friends_fire, axis=0)
        kept = np.compress(alive_foes, kept, axis=1)
        padded[:kept.shape[0], :kept.shape[1]] = kept
        return padded