swarms / distribution_wrap.py
YvesP's picture
initial load
a162e39
import gym
from gym import spaces
import numpy as np
from runner import run_episode
from redux_wrap import ReduxWrapper
from rotate_wrap import RotateWrapper
from symetry_wrap import SymetryWrapper
from sort_wrap import SortWrapper
from team_wrap import TeamWrapper
from reward_wrap import RewardWrapper
class DistriWrapper(gym.Wrapper):
"""
:param env: (gym.Env) Gym environment that will be wrapped
"""
def __init__(self, env):
self.blue_deads = self.red_deads = 0
self.nb_blues, self.nb_reds = env.nb_blues, env.nb_reds
env.observation_space = spaces.Tuple((
spaces.Box(low=0, high=1, shape=(self.nb_blues, 6), dtype=np.float32),
spaces.Box(low=0, high=1, shape=(self.nb_reds, 6), dtype=np.float32),
spaces.Box(low=0, high=1, shape=(self.nb_blues, self.nb_reds), dtype=np.float32),
spaces.Box(low=0, high=1, shape=(self.nb_reds, self.nb_blues), dtype=np.float32)))
env.action_space = spaces.Tuple((
spaces.Box(low=0, high=1, shape=(self.nb_blues, 3), dtype=np.float32),
spaces.Box(low=0, high=1, shape=(self.nb_reds, 3), dtype=np.float32)))
# Call the parent constructor, so we can access self.env later
super(DistriWrapper, self).__init__(env)
def reset(self):
"""
Reset the environment
"""
obs = self.env.reset()
blue_obs, red_obs, blues_fire, reds_fire, blue_deads, red_deads = obs
self.blue_deads, self.blue_deads = blue_deads, red_deads
return blue_obs, red_obs, blues_fire, reds_fire
def step(self, action):
"""
:param action: ([float] or int) Action taken by the agent
:return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional informations
"""
obs, reward, done, info = self.env.step(action)
blue_obs, red_obs, blues_fire, reds_fire, blue_deads, red_deads = obs
obs = blue_obs, red_obs, blues_fire, reds_fire
if done: # environment decision (eg drones oob)
return obs, reward, True, info
if red_deads == len(red_obs): # no more reds to fight (it could mean that they have all reached the target)
return obs, reward, True, info
if blue_deads == len(blue_obs): # reds have won
return obs, reward, True, info
# do we have new deaths?
new_blue_deads = blue_deads - self.blue_deads
new_red_deads = red_deads - self.red_deads
self.blue_deads, self.red_deads = blue_deads, red_deads
if 0 < new_red_deads + new_blue_deads: # we have someone killed but we still have some fight
blues, reds = self.nb_blues - blue_deads, self.nb_reds - red_deads
env = ReduxWrapper(self, minus_blue=blue_deads, minus_red=red_deads)
obs_ = env.post_obs(obs)
env = RotateWrapper(env)
obs_ = env.post_obs(obs_)
env = SymetryWrapper(env)
obs_ = env.post_obs(obs_)
env = SortWrapper(env)
obs_ = env.post_obs(obs_)
env = RewardWrapper(TeamWrapper(env, is_double=True), is_double=True)
obs_ = env.post_obs(obs_)
_, reward, done, info = run_episode(env, obs_, blues=blues, reds=reds)
return obs, reward, done, info