# swarms/team_wrap.py
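"""Gym wrapper exposing the two-team swarm environment through flat,
[-1, 1]-normalised Box spaces, with the opposing team optionally driven
by a scripted SwarmPolicy."""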
import numpy as np
import gym
from gym import spaces
from swarm_policy import SwarmPolicy
from settings import Settings


class TeamWrapper(gym.Wrapper):
    """
    Gym wrapper that flattens the swarm observations into a normalised Box space
    and routes flat [-1, 1] actions back to the underlying two-team environment.

    :param env: (gym.Env) Gym environment that will be wrapped
    :param is_blue: (bool) whether the agent controls the blue team
    :param is_double: (bool) whether the agent controls both teams at once
    :param is_unkillable: (bool) whether drones are unkillable; if so, only one
        fire matrix is included in the observation
    """

    def __init__(self, env, is_blue: bool = True, is_double: bool = False,
                 is_unkillable: bool = Settings.is_unkillable):
        self.is_blue = is_blue
        self.is_double = is_double
        self.is_unkillable = is_unkillable
        nb_blues, nb_reds = env.nb_blues, env.nb_reds
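        # When the agent controls a single team, the opposing team is driven by a
        # scripted SwarmPolicy; its action is refreshed after every observation in
        # post_obs() and consumed in step().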
        self.foe_action = None
        self.foe_policy = SwarmPolicy(is_blue=not is_blue, blues=nb_blues, reds=nb_reds)
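        # Each drone is commanded by 3 continuous values in [-1, 1]. In double mode
        # the agent outputs actions for both teams (a Tuple of two Boxes); otherwise
        # it only outputs actions for its own team.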
        if is_double:
            env.action_space = spaces.Tuple((
                spaces.Box(low=-1, high=1, shape=(nb_blues * 3,), dtype=np.float32),
                spaces.Box(low=-1, high=1, shape=(nb_reds * 3,), dtype=np.float32)
            ))
        else:
            nb_friends = nb_blues if is_blue else nb_reds
            env.action_space = spaces.Box(low=-1, high=1, shape=(nb_friends * 3,), dtype=np.float32)

        flatten_dimension = 6 * nb_blues + 6 * nb_reds  # positions and speeds of the blue and red drones
        flatten_dimension += (nb_blues * nb_reds) * (1 if is_unkillable else 2)  # the fire matrices
        env.observation_space = spaces.Box(low=-1, high=1, shape=(flatten_dimension,), dtype=np.float32)

        super().__init__(env)

    def reset(self):
        """
        Reset the environment
        """
        obs = self.env.reset()
        obs = self.post_obs(obs)
        return obs

    def step(self, action):
        """
        :param action: (np.ndarray or tuple of np.ndarray) flat action(s) taken by the agent, in [-1, 1]
        :return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional information
        """
        if self.is_double:
            blue_action, red_action = action
            blue_action = _decentralise(blue_action)
            red_action = _decentralise(red_action)
            action = _unflatten(blue_action), _unflatten(red_action)
        else:
            friend_action = _decentralise(action)
            foe_action = _decentralise(self.foe_action)
            if self.is_blue:
                action = _unflatten(friend_action), _unflatten(foe_action)
            else:
                action = _unflatten(foe_action), _unflatten(friend_action)

        obs, reward, done, info = self.env.step(action)
        obs = self.post_obs(obs)
        return obs, reward, done, info

    def post_obs(self, obs):
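        # Drop the last observation component when drones are unkillable, flatten
        # everything into a single vector, rescale from [0, 1] to [-1, 1], and (in
        # single-team mode) pre-compute the scripted foe action for the next step.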
        if self.is_unkillable:
            o1, o2, o3, _ = obs
            obs = o1, o2, o3
        flatten_obs = _flatten(obs)
        centralised_obs = _centralise(flatten_obs)
        if not self.is_double:
            self.foe_action = self.foe_policy.predict(centralised_obs)
        return centralised_obs
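

# Module-level helpers converting between the agent's flat, [-1, 1]-normalised
# arrays and the per-drone, [0, 1]-scaled representation used by the environment.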
def _unflatten(action):  # flat action vector -> list of per-drone 3-vectors
    return np.split(action, len(action) // 3)


def _flatten(obs):  # need normalisation too
    fl_obs = [this_obs.flatten().astype('float32') for this_obs in obs]
    fl_obs = np.hstack(fl_obs)
    return fl_obs


def _centralise(obs):  # [0, 1] to [-1, 1]
    obs = 2 * obs - 1
    return obs


def _decentralise(act):  # [-1, 1] to [0, 1]
    act = 0.5 * (act + 1)
    return act
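

# Example usage (sketch only; `SwarmEnv` is a placeholder for whatever environment
# exposes `nb_blues` / `nb_reds` and the tuple observation consumed by post_obs()):
#
#     env = TeamWrapper(SwarmEnv(), is_blue=True)
#     obs = env.reset()
#     obs, reward, done, info = env.step(env.action_space.sample())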