|
import gym |
|
|
|
from gym import spaces |
|
import numpy as np |
|
|
|
from runner import run_episode |
|
from redux_wrap import ReduxWrapper |
|
from rotate_wrap import RotateWrapper |
|
from symetry_wrap import SymetryWrapper |
|
from sort_wrap import SortWrapper |
|
from team_wrap import TeamWrapper |
|
from reward_wrap import RewardWrapper |
|
|
|
|
|
class DistriWrapper(gym.Wrapper):
    """
    Splits the wrapped environment's raw observation tuple into per-team
    observations and tracks the number of dead agents on each team, ending
    the episode early when one team is wiped out.

    :param env: (gym.Env) Gym environment that will be wrapped; must expose
        ``nb_blues`` and ``nb_reds`` attributes.
    """

    def __init__(self, env):
        # Running dead-agent counters, refreshed on reset() and step().
        self.blue_deads = self.red_deads = 0
        self.nb_blues, self.nb_reds = env.nb_blues, env.nb_reds

        # Observation: one (n, 6) feature box per team plus the two
        # cross-team fire matrices (blue->red and red->blue).
        env.observation_space = spaces.Tuple((
            spaces.Box(low=0, high=1, shape=(self.nb_blues, 6), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_reds, 6), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_blues, self.nb_reds), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_reds, self.nb_blues), dtype=np.float32)))

        # Action: one (n, 3) continuous box per team.
        env.action_space = spaces.Tuple((
            spaces.Box(low=0, high=1, shape=(self.nb_blues, 3), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_reds, 3), dtype=np.float32)))

        super(DistriWrapper, self).__init__(env)

    def reset(self):
        """
        Reset the environment and the dead-agent counters.

        :return: (tuple) blue_obs, red_obs, blues_fire, reds_fire
        """
        obs = self.env.reset()
        blue_obs, red_obs, blues_fire, reds_fire, blue_deads, red_deads = obs
        # BUG FIX: the original wrote `self.blue_deads, self.blue_deads = ...`,
        # assigning blue_deads twice and leaving self.red_deads stale across
        # episodes, which corrupted the per-step death deltas in step().
        self.blue_deads, self.red_deads = blue_deads, red_deads
        return blue_obs, red_obs, blues_fire, reds_fire

    def step(self, action):
        """
        :param action: ([float] or int) Action taken by the agent
        :return: (np.ndarray, float, bool, dict) observation, reward,
            is the episode over?, additional informations
        """
        obs, reward, done, info = self.env.step(action)

        blue_obs, red_obs, blues_fire, reds_fire, blue_deads, red_deads = obs
        obs = blue_obs, red_obs, blues_fire, reds_fire

        # End the episode when the env says so, or when either team is
        # entirely dead (the three cases returned identically before;
        # collapsed into one guard).
        if done or red_deads == len(red_obs) or blue_deads == len(blue_obs):
            return obs, reward, True, info

        # Deaths that occurred during this step only.
        new_blue_deads = blue_deads - self.blue_deads
        new_red_deads = red_deads - self.red_deads
        self.blue_deads, self.red_deads = blue_deads, red_deads

        if 0 < new_red_deads + new_blue_deads:
            # Somebody just died: replay the remainder of the episode on a
            # reduced view of the env (dead agents stripped, observations
            # rotated/mirrored/sorted into canonical form).
            blues, reds = self.nb_blues - blue_deads, self.nb_reds - red_deads

            env = ReduxWrapper(self, minus_blue=blue_deads, minus_red=red_deads)
            obs_ = env.post_obs(obs)

            env = RotateWrapper(env)
            obs_ = env.post_obs(obs_)

            env = SymetryWrapper(env)
            obs_ = env.post_obs(obs_)

            env = SortWrapper(env)
            obs_ = env.post_obs(obs_)

            env = RewardWrapper(TeamWrapper(env, is_double=True), is_double=True)
            obs_ = env.post_obs(obs_)

            _, reward, done, info = run_episode(env, obs_, blues=blues, reds=reds)

        return obs, reward, done, info
|
|