import numpy as np
from gym import spaces, Wrapper


class FilterWrapper(Wrapper):
    """
    Wrapper that packs alive agents first and hides dead ones from the policy.

    Observations coming out of the wrapped environment are re-ordered so that
    the rows of alive agents come first. Actions going in are expected to
    hold one entry per *alive* agent and are expanded back to one entry per
    agent (dead agents receive a zero action) before being forwarded.

    :param env: (gym.Env) Gym environment that will be wrapped; it must expose
        the attributes ``nb_blues`` and ``nb_reds``
    """

    def __init__(self, env):
        self.nb_blues, self.nb_reds = env.nb_blues, env.nb_reds
        # Nobody is flagged dead until an observation says otherwise.
        self.blue_deads = np.full((self.nb_blues,), False)
        self.red_deads = np.full((self.nb_reds,), False)
        # NOTE(review): this overwrites the observation space of the wrapped
        # env itself (not only the wrapper's); kept as-is in case callers
        # rely on it.
        # NOTE(review): Discrete(1) only contains the value 0, while the last
        # two observation entries are dead counts (see _sort_obs) -- confirm
        # the intended size of these spaces.
        env.observation_space = spaces.Tuple((
            spaces.Box(low=0, high=1, shape=(self.nb_blues, 6), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_reds, 6), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_blues, self.nb_reds), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_reds, self.nb_blues), dtype=np.float32),
            spaces.Discrete(1),
            spaces.Discrete(1)))
        super(FilterWrapper, self).__init__(env)

    def reset(self):
        """
        Reset the environment

        :return: (tuple) the first observation, re-ordered by ``_sort_obs``
        """
        obs = self.env.reset()
        return self._sort_obs(obs)

    def step(self, action):
        """
        Forward an action to the wrapped environment.

        :param action: (tuple) pair ``(blue_action, red_action)``; each element
            is indexable and holds one action per alive agent, in agent order
        :return: (tuple, float, bool, dict) observation, reward, is the episode over?,
            additional informations
        """
        blue_action, red_action = action
        # The wrapped env expects one action per agent, dead or alive.
        action = (self._pad_actions(self.blue_deads, blue_action),
                  self._pad_actions(self.red_deads, red_action))
        obs, reward, done, info = self.env.step(action)
        obs = self._sort_obs(obs)
        return obs, reward, done, info

    @staticmethod
    def _pad_actions(deads, actions):
        """
        Expand per-alive-agent actions to one action per agent.

        :param deads: (np.ndarray) boolean flags, True where the agent is dead
        :param actions: (sequence) actions of the alive agents, in agent order
        :return: (list) one action per agent; dead agents get ``np.array([0, 0, 0])``
        """
        padded = []
        next_alive = 0
        for dead in deads:
            if dead:
                padded.append(np.array([0, 0, 0]))
            else:
                padded.append(actions[next_alive])
                next_alive += 1
        return padded

    def _sort_obs(self, obs):
        """
        Re-order an observation so that alive agents come first.

        Also stores the dead flags on the wrapper so that the next call to
        ``step`` knows which agents need a padded action.

        :param obs: (tuple) ``(blue_obs, red_obs, blues_fire, reds_fire,
            blue_deads, red_deads)`` as returned by the wrapped environment
        :return: (tuple) ``(blue_obs, red_obs, blues_fire, reds_fire,
            nb_dead_blues, nb_dead_reds)``
        """
        blue_obs, red_obs, blues_fire, reds_fire, blue_deads, red_deads = obs
        # Coerce to boolean arrays: `~` on 0/1 integers is a bitwise NOT
        # (-1/-2, both truthy) and is undefined on plain lists. This is a
        # no-op when the env already returns boolean arrays.
        self.blue_deads = np.asarray(blue_deads, dtype=bool)
        self.red_deads = np.asarray(red_deads, dtype=bool)

        # Alive rows first, dead rows last; relative order is preserved.
        blue_obs = np.vstack((blue_obs[~self.blue_deads], blue_obs[self.blue_deads]))
        red_obs = np.vstack((red_obs[~self.red_deads], red_obs[self.red_deads]))

        blues_fire = self.fire_sort(self.blue_deads, self.red_deads, blues_fire)
        reds_fire = self.fire_sort(self.red_deads, self.blue_deads, reds_fire)

        return (blue_obs, red_obs, blues_fire, reds_fire,
                self.blue_deads.sum(), self.red_deads.sum())

    def fire_sort(self, dead_friends, dead_foes, friends_fire):
        """
        Pack the fire matrix of the alive agents into the top-left corner.

        Rows of dead friends and columns of dead foes are dropped; the result
        keeps the shape and dtype of ``friends_fire`` and is zero outside the
        packed block.

        :param dead_friends: (np.ndarray) dead flags for the rows of ``friends_fire``
        :param dead_foes: (np.ndarray) dead flags for the columns of ``friends_fire``
        :param friends_fire: (np.ndarray) 2-D matrix indexed [friend, foe]
        :return: (np.ndarray) new matrix, same shape as ``friends_fire``
        """
        # Same boolean coercion as in _sort_obs, so that direct callers may
        # pass lists or 0/1 integer flags as well.
        alive_friends = ~np.asarray(dead_friends, dtype=bool)
        alive_foes = ~np.asarray(dead_foes, dtype=bool)

        packed = np.compress(alive_friends, friends_fire, axis=0)
        packed = np.compress(alive_foes, packed, axis=1)

        friends_fire_big = np.zeros_like(friends_fire)
        friends_fire_big[:packed.shape[0], :packed.shape[1]] = packed
        return friends_fire_big