from typing import Any, Union
import numpy as np
import gym
from ding.envs import BaseEnv, BaseEnvTimestep
from ding.torch_utils import to_ndarray
from ding.utils import ENV_REGISTRY
from .mujoco_multi import MujocoMulti


@ENV_REGISTRY.register('mujoco_multi')
class MujocoEnv(BaseEnv):
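    """
    Overview:
        DI-engine wrapper of the multi-agent MuJoCo benchmark (``MujocoMulti``). The underlying
        environment is created lazily on the first ``reset``, and the observation, action and
        reward spaces are exposed as ``gym.spaces.Dict`` objects with one entry per agent.
    """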

    def __init__(self, cfg: dict) -> None:
        self._cfg = cfg
        self._init_flag = False

    def reset(self) -> np.ndarray:
        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
            np_seed = 100 * np.random.randint(1, 1000)
            self._cfg.seed = self._seed + np_seed
        elif hasattr(self, '_seed'):
            self._cfg.seed = self._seed
        if not self._init_flag:
            self._env = MujocoMulti(env_args=self._cfg)
            self._init_flag = True
        obs = self._env.reset()
        self._eval_episode_return = 0.

        # TODO: for scenario='Ant-v2', agent_conf="2x4d", get_env_info() returns
        # {'state_shape': 2, 'obs_shape': 54, ...}, but 'state_shape' should be 111.
        # Note that self._env.observation_space[agent].shape equals this 'state_shape'.
        self.env_info = self._env.get_env_info()

        self._num_agents = self.env_info['n_agents']
        self._agents = list(range(self._num_agents))
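        # The reset observation is a dict: 'agent_state' holds the per-agent observations
        # and 'global_state' holds the shared global state; both shapes are taken from it.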
        self._observation_space = gym.spaces.Dict(
            {
                'agent_state': gym.spaces.Box(
                    low=float("-inf"), high=float("inf"), shape=obs['agent_state'].shape, dtype=np.float32
                ),
                'global_state': gym.spaces.Box(
                    low=float("-inf"), high=float("inf"), shape=obs['global_state'].shape, dtype=np.float32
                ),
            }
        )
        self._action_space = gym.spaces.Dict({agent: self._env.action_space[agent] for agent in self._agents})
        single_agent_action_space = self._env.action_space[self._agents[0]]
        if isinstance(single_agent_action_space, gym.spaces.Box):
            self._action_dim = single_agent_action_space.shape
        elif isinstance(single_agent_action_space, gym.spaces.Discrete):
            self._action_dim = (single_agent_action_space.n, )
        else:
            raise Exception('Only support `Box` or `Discrete` action space for single agent.')
        self._reward_space = gym.spaces.Dict(
            {
                agent: gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32)
                for agent in self._agents
            }
        )

        return obs

    def close(self) -> None:
        if self._init_flag:
            self._env.close()
        self._init_flag = False

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
        action = to_ndarray(action)
        obs, rew, done, info = self._env.step(action)
        self._eval_episode_return += rew
        rew = to_ndarray([rew])  # wrap the scalar reward into an array with shape (1, )
        if done:
            info['eval_episode_return'] = self._eval_episode_return
        return BaseEnvTimestep(obs, rew, done, info)

    def random_action(self) -> np.ndarray:
        # Sample one action per agent; multi-agent MuJoCo actions are continuous,
        # so keep them as floats instead of casting to integers.
        random_action = self.action_space.sample()
        random_action = to_ndarray([random_action], dtype=np.float32)
        return random_action

    @property
    def num_agents(self) -> Any:
        return self._num_agents

    @property
    def observation_space(self) -> gym.spaces.Space:
        return self._observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        return self._action_space

    @property
    def reward_space(self) -> gym.spaces.Space:
        return self._reward_space

    def __repr__(self) -> str:
        return "DI-engine Multi-agent Mujoco Env({})".format(self._cfg.env_id)