|
|
|
import numpy as np |
|
from gym import spaces, Wrapper |
|
|
|
|
|
class FilterWrapper(Wrapper):
    """
    Wrapper that puts living units first in observations and pads actions for dead ones.

    The wrapped environment is expected to return observations as a 6-tuple
    ``(blue_obs, red_obs, blues_fire, reds_fire, blue_deads, red_deads)``.
    This wrapper reorders the rows of each array so that living units come
    first, replaces the two dead-flag vectors by the number of dead units on
    each side, and re-inserts a zero action for every dead unit before
    forwarding actions to the wrapped environment.

    :param env: (gym.Env) Gym environment that will be wrapped; must expose
        ``nb_blues`` and ``nb_reds``
    """

    def __init__(self, env):
        # Initialise the wrapper first so the filtered observation space is
        # attached to the wrapper itself and the wrapped env is left untouched.
        super(FilterWrapper, self).__init__(env)

        self.nb_blues, self.nb_reds = env.nb_blues, env.nb_reds

        # Dead flags of the latest observation; nobody is dead before the first one.
        self.blue_deads = np.full((self.nb_blues,), False)
        self.red_deads = np.full((self.nb_reds,), False)

        self.observation_space = spaces.Tuple((
            spaces.Box(low=0, high=1, shape=(self.nb_blues, 6), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_reds, 6), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_blues, self.nb_reds), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_reds, self.nb_blues), dtype=np.float32),
            # The last two entries are dead counts, which range over 0..nb
            # inclusive, hence nb + 1 possible values.
            spaces.Discrete(self.nb_blues + 1),
            spaces.Discrete(self.nb_reds + 1),
        ))

    def reset(self):
        """
        Reset the environment

        :return: (tuple) first observation, sorted with living units first
        """
        return self._sort_obs(self.env.reset())

    def step(self, action):
        """
        :param action: (tuple) pair ``(blue_action, red_action)``; each element holds
            one action per living unit, in the order of the sorted observation
        :return: (tuple, float, bool, dict) observation, reward, is the episode over?,
            additional informations
        """
        blue_action, red_action = action

        # The wrapped env expects one action per unit, dead or alive.
        full_action = (
            self._expand_actions(blue_action, self.blue_deads),
            self._expand_actions(red_action, self.red_deads),
        )

        obs, reward, done, info = self.env.step(full_action)

        return self._sort_obs(obs), reward, done, info

    @staticmethod
    def _expand_actions(actions, deads):
        """
        Build the per-unit action list expected by the wrapped environment.

        :param actions: (sequence) actions of the living units, in unit order
        :param deads: (array-like of bool) dead flag of every unit
        :return: (list) one action per unit; dead units get ``np.array([0, 0, 0])``
        """
        expanded = []
        cursor = 0  # position of the next unused action of a living unit
        for dead in np.asarray(deads, dtype=bool):
            if dead:
                # Fresh array per slot so that entries never alias each other.
                expanded.append(np.array([0, 0, 0]))
            else:
                expanded.append(actions[cursor])
                cursor += 1
        return expanded

    def _sort_obs(self, obs):
        """
        Reorder a raw observation so that living units come first.

        Also records the dead flags on the wrapper so that the next call to
        ``step`` knows where to insert padding actions.

        :param obs: (tuple) ``(blue_obs, red_obs, blues_fire, reds_fire,
            blue_deads, red_deads)`` as returned by the wrapped environment
        :return: (tuple) ``(blue_obs, red_obs, blues_fire, reds_fire,
            nb_dead_blues, nb_dead_reds)``
        """
        blue_obs, red_obs, blues_fire, reds_fire, blue_deads, red_deads = obs

        # Coerce to boolean arrays: ``~`` is only a logical NOT on bool dtype.
        self.blue_deads = np.asarray(blue_deads, dtype=bool)
        self.red_deads = np.asarray(red_deads, dtype=bool)

        blue_obs = np.asarray(blue_obs)
        red_obs = np.asarray(red_obs)

        # Living rows first, dead rows after, relative order kept in each group.
        blue_obs = np.vstack((blue_obs[~self.blue_deads], blue_obs[self.blue_deads]))
        red_obs = np.vstack((red_obs[~self.red_deads], red_obs[self.red_deads]))

        blues_fire = self.fire_sort(self.blue_deads, self.red_deads, blues_fire)
        reds_fire = self.fire_sort(self.red_deads, self.blue_deads, reds_fire)

        nb_dead_blues = int(np.count_nonzero(self.blue_deads))
        nb_dead_reds = int(np.count_nonzero(self.red_deads))

        return blue_obs, red_obs, blues_fire, reds_fire, nb_dead_blues, nb_dead_reds

    def fire_sort(self, dead_friends, dead_foes, friends_fire):
        """
        Pack the living-vs-living part of a fire matrix into its top-left corner.

        :param dead_friends: (array-like of bool) dead flags along axis 0
        :param dead_foes: (array-like of bool) dead flags along axis 1
        :param friends_fire: (np.ndarray) matrix indexed by (friend, foe)
        :return: (np.ndarray) array of the same shape and dtype; rows and columns of
            dead units are dropped, the remainder sits in the top-left corner and
            every other entry is zero
        """
        alive_friends = ~np.asarray(dead_friends, dtype=bool)
        alive_foes = ~np.asarray(dead_foes, dtype=bool)

        packed = np.zeros_like(friends_fire)
        kept = np.compress(alive_friends, friends_fire, axis=0)
        kept = np.compress(alive_foes, kept, axis=1)
        packed[:kept.shape[0], :kept.shape[1]] = kept
        return packed
|
|