import gym
import numpy as np

from gym import spaces

from runner import run_episode
from redux_wrap import ReduxWrapper
from rotate_wrap import RotateWrapper
from symetry_wrap import SymetryWrapper
from sort_wrap import SortWrapper
from team_wrap import TeamWrapper
from reward_wrap import RewardWrapper


class DistriWrapper(gym.Wrapper):
    """
    :param env: (gym.Env) Gym environment that will be wrapped
    """

    def __init__(self, env):

        # Casualty counters, refreshed from each raw observation
        self.blue_deads = self.red_deads = 0
        self.nb_blues, self.nb_reds = env.nb_blues, env.nb_reds

        # Observations: one 6-feature row per blue/red drone, plus the
        # blue->red and red->blue fire matrices
        env.observation_space = spaces.Tuple((
            spaces.Box(low=0, high=1, shape=(self.nb_blues, 6), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_reds, 6), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_blues, self.nb_reds), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_reds, self.nb_blues), dtype=np.float32)))

        # Actions: one 3-dimensional continuous action per drone, per team
        env.action_space = spaces.Tuple((
            spaces.Box(low=0, high=1, shape=(self.nb_blues, 3), dtype=np.float32),
            spaces.Box(low=0, high=1, shape=(self.nb_reds, 3), dtype=np.float32)))

        # Call the parent constructor, so we can access self.env later
        super(DistriWrapper, self).__init__(env)

    def reset(self):
        """
        Reset the environment
        """
        obs = self.env.reset()
        blue_obs, red_obs, blues_fire, reds_fire, blue_deads, red_deads = obs
        self.blue_deads, self.red_deads = blue_deads, red_deads
        return blue_obs, red_obs, blues_fire, reds_fire

    def step(self, action):
        """
        :param action: ([float] or int) Action taken by the agent
        :return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional informations
        """
        obs, reward, done, info = self.env.step(action)

        blue_obs, red_obs, blues_fire, reds_fire, blue_deads, red_deads = obs
        obs = blue_obs, red_obs, blues_fire, reds_fire

        if done:  # environment decision (e.g. drones out of bounds)
            return obs, reward, True, info

        if red_deads == len(red_obs):  # no more reds to fight (they may all have reached the target)
            return obs, reward, True, info

        if blue_deads == len(blue_obs):  # reds have won
            return obs, reward, True, info

        # do we have new deaths?
        new_blue_deads = blue_deads - self.blue_deads
        new_red_deads = red_deads - self.red_deads
        self.blue_deads, self.red_deads = blue_deads, red_deads

        if new_blue_deads + new_red_deads > 0:  # someone was killed, but the fight goes on

            blues, reds = self.nb_blues - blue_deads, self.nb_reds - red_deads

            # Re-wrap the environment around the survivors: ReduxWrapper shrinks
            # the spaces by the casualty counts, then the remaining wrappers
            # re-apply the usual observation transformations before the episode
            # is resumed with run_episode
            env = ReduxWrapper(self, minus_blue=blue_deads, minus_red=red_deads)
            obs_ = env.post_obs(obs)

            env = RotateWrapper(env)
            obs_ = env.post_obs(obs_)

            env = SymetryWrapper(env)
            obs_ = env.post_obs(obs_)

            env = SortWrapper(env)
            obs_ = env.post_obs(obs_)

            env = RewardWrapper(TeamWrapper(env, is_double=True), is_double=True)
            obs_ = env.post_obs(obs_)

            _, reward, done, info = run_episode(env, obs_, blues=blues, reds=reds)

        return obs, reward, done, info
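

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): `_DummyEnv` is a hypothetical
    # stand-in, not part of this repository. It merely mimics the 6-tuple
    # observation contract the wrapper expects, using the old-style gym API
    # followed throughout this file.
    class _DummyEnv(gym.Env):
        nb_blues, nb_reds = 3, 3

        def _obs(self):
            blue = np.zeros((self.nb_blues, 6), dtype=np.float32)
            red = np.zeros((self.nb_reds, 6), dtype=np.float32)
            blues_fire = np.zeros((self.nb_blues, self.nb_reds), dtype=np.float32)
            reds_fire = np.zeros((self.nb_reds, self.nb_blues), dtype=np.float32)
            return blue, red, blues_fire, reds_fire, 0, 0  # zero deads

        def reset(self):
            return self._obs()

        def step(self, action):
            return self._obs(), 0.0, True, {}  # ends immediately

    env = DistriWrapper(_DummyEnv())
    obs = env.reset()
    obs, reward, done, info = env.step(env.action_space.sample())
    print("reward:", reward, "done:", done)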