# swarms / symetry_wrap.py
# YvesP's picture
# initial load
# a162e39
import numpy as np
import gym
class SymetryWrapper(gym.Wrapper):
    """
    Gym wrapper that flips observations and actions whenever ``get_symetry``
    flags the current observation.

    Each observation coming out of the wrapped environment goes through
    ``post_obs``: the result of ``get_symetry`` is stored in ``self.symetry``
    and, when it is true, the observation is passed through ``symetrise_obs``.
    The stored flag is then used by the next ``step`` call to pass the
    incoming action through ``symetrise_action`` before it reaches the
    wrapped environment.

    :param env: (gym.Env) Gym environment that will be wrapped
    """

    def __init__(self, env):
        # Flag describing the last observation returned to the agent:
        # False means no symetry has to be applied to the next action.
        self.symetry = False
        # Call the parent constructor, so we can access self.env later
        super().__init__(env)

    def reset(self, **kwargs):
        """
        Reset the environment

        :param kwargs: keyword arguments forwarded unchanged to the wrapped
            environment's ``reset`` (none are required)
        :return: the first observation, after ``post_obs`` processing
        """
        obs = self.env.reset(**kwargs)
        return self.post_obs(obs)

    def step(self, action):
        """
        :param action: ([float] or int) Action taken by the agent
        :return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional informations
        """
        # The agent chose its action from the last observation it received;
        # if that observation was symetrised, the action must be too.
        # NOTE: symetrise_action modifies the given action in place.
        if self.symetry:
            action = symetrise_action(action)
        obs, reward, done, info = self.env.step(action)
        return self.post_obs(obs), reward, done, info

    def post_obs(self, obs):
        """
        Update the symetry flag from ``obs`` and symetrise ``obs`` if needed

        :param obs: observation produced by the wrapped environment
            (the 4-tuple unpacked by ``get_symetry``)
        :return: the observation, symetrised when ``self.symetry`` is true
        """
        self.symetry = get_symetry(obs)
        if self.symetry:
            obs = symetrise_obs(obs)
        return obs
def get_symetry(obs):
    """
    Decide whether an observation has to be symetrised

    :param obs: 4-tuple (blue_obs, red_obs, blue_fire, red_fire); only the
        two first items are inspected, each being a sequence of per-drone
        vectors
    :return: (bool) True when strictly fewer than half of all drones have
        their index-1 component below 0.5
    """
    # the fire entries are not used, but unpacking still enforces a 4-tuple
    blue_obs, red_obs, _blue_fire, _red_fire = obs

    # number of drones whose index-1 coordinate lies below 0.5
    # (presumably a normalised position -- confirm against the environment)
    below_half = sum(
        1
        for team in (blue_obs, red_obs)
        for drone in team
        if drone[1] < 0.5
    )

    total = len(blue_obs) + len(red_obs)
    return bool(2 * below_half < total)
def symetrise_obs(obs):
    """
    Symetrise the blue and red observations, in place

    Columns 1 and 4 of both team arrays are replaced by ``1 - value``;
    the fire entries are passed through untouched.

    :param obs: 4-tuple (blue_obs, red_obs, blue_fire, red_fire) where the
        two first items are 2-D arrays with at least 5 columns
    :return: tuple (blue_obs, red_obs, blue_fire, red_fire) holding the same
        (now modified) array objects
    """
    blue_obs, red_obs, blue_fire, red_fire = obs

    # column 1 then column 4: position and speed according to the original
    # author's comment
    mirrored_columns = (1, 4)
    for team in (blue_obs, red_obs):
        for col in mirrored_columns:
            team[:, col] = 1 - team[:, col]

    return blue_obs, red_obs, blue_fire, red_fire
def symetrise_action(action):
    """
    Symetrise the blue and red actions, in place

    The index-1 component of every individual action is negated.

    :param action: pair (blue_action, red_action), each an iterable of
        mutable per-drone action vectors
    :return: tuple (blue_action, red_action) holding the same (now modified)
        objects
    """
    blue_action, red_action = action

    # walk through every drone action of both teams in a single pass
    every_drone_action = (
        drone_action
        for team_actions in (blue_action, red_action)
        for drone_action in team_actions
    )
    for drone_action in every_drone_action:
        drone_action[1] = -drone_action[1]

    return blue_action, red_action
def test_symetrise_obs():
    """
    Check that symetrise_obs flips columns 1 and 4 of both team arrays

    The observation is printed before and after the call (as before), and
    the result is now also verified: columns 1 and 4 of the blue and red
    arrays must equal ``1 - original``, every other column and both fire
    arrays must be unchanged. The check relies on symetrise_obs modifying
    the arrays in place.
    """
    obs = (
        np.arange(12).reshape(2, 6),
        np.arange(12).reshape(2, 6),
        np.random.random((1, 1)),
        np.random.random((1, 1)),
    )
    # snapshot taken before the in-place transformation
    before = tuple(part.copy() for part in obs)

    print(obs)
    symetrise_obs(obs)
    print(obs)

    flipped = [1, 4]
    untouched = [0, 2, 3, 5]
    for team, team_before in zip(obs[:2], before[:2]):
        assert np.array_equal(team[:, flipped], 1 - team_before[:, flipped])
        assert np.array_equal(team[:, untouched], team_before[:, untouched])
    for fire, fire_before in zip(obs[2:], before[2:]):
        assert np.array_equal(fire, fire_before)