|
import numpy as np |
|
import gym |
|
|
|
|
|
class SymetryWrapper(gym.Wrapper):
    """
    Gym wrapper that mirrors observations (and the corresponding actions)
    whenever most units sit on the "high" side of the symmetry axis.

    :param env: (gym.Env) Gym environment that will be wrapped
    """

    def __init__(self, env):
        # Whether the current state is mirrored; refreshed on every
        # observation by post_obs().
        self.symetry = False
        super().__init__(env)

    def reset(self):
        """
        Reset the environment and return the (possibly mirrored) observation.
        """
        return self.post_obs(self.env.reset())

    def step(self, action):
        """
        :param action: ([float] or int) Action taken by the agent
        :return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional informations
        """
        # The agent acts in the mirrored frame, so its action must be
        # reflected back before being passed to the real environment.
        if self.symetry:
            action = symetrise_action(action)

        obs, reward, done, info = self.env.step(action)
        return self.post_obs(obs), reward, done, info

    def post_obs(self, obs):
        # Decide per-observation whether mirroring applies, and mirror if so.
        self.symetry = get_symetry(obs)
        return symetrise_obs(obs) if self.symetry else obs
|
|
|
|
|
def get_symetry(obs):
    """
    Decide whether the observation should be mirrored.

    A unit counts as being on the "low" side when the second coordinate of
    its row is below 0.5.  Mirroring is requested when strictly fewer than
    half of all units (blue and red combined) are on the low side.

    :param obs: (tuple) (blue_obs, red_obs, blue_fire, red_fire)
    :return: (bool) True when the observation should be mirrored
    """
    blue_obs, red_obs, _blue_fire, _red_fire = obs

    low_side = sum(
        1
        for team in (blue_obs, red_obs)
        for row in team
        if row[1] < 0.5
    )

    total_units = len(blue_obs) + len(red_obs)
    return 2 * low_side < total_units
|
|
|
|
|
def symetrise_obs(obs):
    """
    Mirror an observation about the symmetry axis, in place.

    Columns 1 and 4 of each team's observation matrix are reflected via
    ``x -> 1 - x``.  The input arrays are mutated; the (mutated) components
    are also returned as a tuple for convenience.

    :param obs: (tuple) (blue_obs, red_obs, blue_fire, red_fire)
    :return: (tuple) the same four components, with positions mirrored
    """
    blue_obs, red_obs, blue_fire, red_fire = obs

    for team_obs in (blue_obs, red_obs):
        # Both mirrored quantities live in fixed columns of the matrix.
        for col in (1, 4):
            team_obs[:, col] = 1 - team_obs[:, col]

    return blue_obs, red_obs, blue_fire, red_fire
|
|
|
|
|
def symetrise_action(action):
    """
    Mirror an action about the symmetry axis, in place.

    The sign of each unit's index-1 action component is flipped; the other
    components are untouched.  Each per-unit action row is mutated.

    :param action: (tuple) (blue_action, red_action), each an iterable of
        per-unit mutable action rows
    :return: (tuple) the (mutated) (blue_action, red_action) pair
    """
    blue_action, red_action = action

    for team_action in (blue_action, red_action):
        for unit_act in team_action:
            # Reflecting the axis negates this component.
            unit_act[1] *= -1

    return blue_action, red_action
|
|
|
|
|
def test_symetrise_obs():
    """
    Check that symetrise_obs reflects columns 1 and 4 (``x -> 1 - x``)
    of both teams' observations, mutating the arrays in place.
    """
    blue = np.arange(12).reshape(2, 6)
    red = np.arange(12).reshape(2, 6)
    blue_fire = np.random.random((1, 1))
    red_fire = np.random.random((1, 1))

    # Build the expected result from an untouched copy.
    expected = blue.copy()
    expected[:, 1] = 1 - expected[:, 1]
    expected[:, 4] = 1 - expected[:, 4]

    out = symetrise_obs((blue, red, blue_fire, red_fire))

    # The transform is in place: both the inputs and the returned arrays
    # must carry the mirrored values, and identity must be preserved.
    assert np.array_equal(blue, expected)
    assert np.array_equal(red, expected)
    assert out[0] is blue and out[1] is red
|
|