File size: 8,258 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
import random
import time
from collections import namedtuple
import pytest
import torch
import numpy as np
from easydict import EasyDict
from functools import partial
import gym

from ding.envs.env.base_env import BaseEnvTimestep
from ding.envs.env_manager.base_env_manager import EnvState
from ding.envs.env_manager import BaseEnvManager, SyncSubprocessEnvManager, AsyncSubprocessEnvManager
from ding.torch_utils import to_tensor, to_ndarray, to_list
from ding.utils import deep_merge_dicts


class FakeEnv(object):
    """
    Fake environment used to exercise env managers in tests.

    Each ``step`` sleeps for a random, scale-dependent duration, and an episode is reported as
    done once the accumulated duration has reached a randomly drawn target. Special string
    values given to ``reset`` / ``step`` ('error', 'error_once', 'catched_error', 'wait',
    'block') trigger failures, delays or hangs on demand.
    """

    def __init__(self, cfg):
        """
        Arguments:
            - cfg: Mapping providing ``'scale'`` (multiplier applied to the target time and to \
                every simulated duration) and ``'name'`` (returned by ``name`` and ``__repr__``).
        """
        # Both fields are read with item access, so a plain dict works as well as an EasyDict.
        self._scale = cfg['scale']
        self._target_time = random.randint(3, 6) * self._scale
        self._current_time = 0
        self._name = cfg['name']
        self._id = time.time()
        self._stat = None
        self._seed = 0
        self._data_count = 0
        self.timeout_flag = False
        self._launched = False
        self._state = EnvState.INIT
        # Toggled by reset('error_once') so that only every second such reset fails.
        self._dead_once = False
        self.observation_space = gym.spaces.Box(
            low=np.array([-1.0, -1.0, -8.0]), high=np.array([1.0, 1.0, 8.0]), shape=(3, ), dtype=np.float32
        )
        self.action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1, ), dtype=np.float32)
        self.reward_space = gym.spaces.Box(
            low=-1 * (3.14 * 3.14 + 0.1 * 8 * 8 + 0.001 * 2 * 2), high=0.0, shape=(1, ), dtype=np.float32
        )

    def reset(self, stat=None):
        """
        Start a new episode and return a random observation of 3 elements.

        Arguments:
            - stat: Optional command. ``'error'`` always raises; ``'error_once'`` raises on every \
                second call made with it; ``'wait'`` sleeps 5 seconds if ``step`` has been called \
                before; ``'block'`` sleeps for 1000 seconds. Any other value is only stored.
        Raises:
            - RuntimeError: Raised through ``dead`` for the error commands above.
        """
        if isinstance(stat, str):
            if stat == 'error':
                self.dead()
            elif stat == 'error_once':
                # Alternate: one call arms the flag and succeeds, the next one disarms it and dies.
                if self._dead_once:
                    self._dead_once = False
                    self.dead()
                else:
                    self._dead_once = True
            elif stat == "wait":
                # Only stall once step() has enabled the timeout flag.
                if self.timeout_flag:
                    time.sleep(5)
            elif stat == "block":
                self.block()

        self._launched = True
        self._current_time = 0
        self._stat = stat
        self._state = EnvState.RUN
        return to_ndarray(torch.randn(3))

    def step(self, action):
        """
        Simulate one step and return a ``BaseEnvTimestep(obs, reward, done, info)``.

        Arguments:
            - action: Normally ignored. The strings ``'error'`` (raise), ``'catched_error'`` \
                (return a finished timestep whose info is ``{'abnormal': True}``), ``'wait'`` \
                (sleep 3 extra seconds) and ``'block'`` (sleep 1000 seconds) act as commands.
        Raises:
            - RuntimeError: Raised through ``dead`` when ``action`` is ``'error'``.
        """
        assert self._launched
        assert not self._state == EnvState.ERROR
        self.timeout_flag = True  # after one step, enable timeout flag
        if isinstance(action, str):
            if action == 'error':
                self.dead()
            elif action == 'catched_error':
                return BaseEnvTimestep(None, None, True, {'abnormal': True})
            elif action == "wait":
                # timeout_flag was set just above, so this delay is unconditional.
                time.sleep(3)
            elif action == 'block':
                self.block()
        obs = to_ndarray(torch.randn(3))
        reward = to_ndarray(torch.randint(0, 2, size=[1]).numpy())
        # ``done`` and info['cur'] reflect the time accumulated before this step's own duration.
        done = self._current_time >= self._target_time
        if done:
            self._state = EnvState.DONE
        simulation_time = random.uniform(0.5, 1) * self._scale
        info = {'name': self._name, 'time': simulation_time, 'tgt': self._target_time, 'cur': self._current_time}
        time.sleep(simulation_time)
        self._current_time += simulation_time
        self._data_count += 1
        return BaseEnvTimestep(obs, reward, done, info)

    def dead(self):
        """Mark the env as errored and raise ``RuntimeError``."""
        self._state = EnvState.ERROR
        raise RuntimeError("env error, current time {}".format(self._current_time))

    def block(self):
        """Mark the env as errored and sleep for 1000 seconds to imitate a hang."""
        self._state = EnvState.ERROR
        time.sleep(1000)

    def close(self):
        """Return the env to its un-launched initial state."""
        self._launched = False
        self._state = EnvState.INIT

    def seed(self, seed):
        """Store ``seed``; nothing in this class reads it back."""
        self._seed = seed

    @property
    def name(self):
        """Name taken from ``cfg['name']``."""
        return self._name

    @property
    def time_id(self):
        """Value of ``time.time()`` captured when the env was constructed."""
        return self._id

    def user_defined(self):
        """No-op placeholder method."""
        pass

    def __repr__(self):
        return self._name


class FakeAsyncEnv(FakeEnv):
    """FakeEnv variant whose ``reset`` additionally sleeps for a random, scale-dependent time."""

    def reset(self, stat=None):
        """Run the parent reset, sleep 1-3 times ``scale`` seconds, then return a fresh observation."""
        # The parent's observation is discarded; a new one is drawn after the delay.
        super().reset(stat)
        delay = random.randint(1, 3) * self._scale
        time.sleep(delay)
        fresh_obs = torch.randn(3)
        return to_ndarray(fresh_obs)


class FakeGymEnv(FakeEnv):
    """FakeEnv variant with a ``metadata`` attribute, a 4-element action space and ``random_action``."""

    def __init__(self, cfg):
        """Build the parent env, then set ``metadata`` and replace the action space with shape ``(4, )``."""
        super().__init__(cfg)
        self.metadata = "fake metadata"
        self.action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(4, ), dtype=np.float32)

    def random_action(self) -> np.ndarray:
        """
        Sample an action from ``action_space`` and normalise its container type.

        An ``np.ndarray`` sample is returned untouched, an ``int`` is wrapped into an int64 array
        and a ``dict`` is passed through ``to_ndarray``; any other type raises ``TypeError``.
        """
        sample = self.action_space.sample()
        if isinstance(sample, np.ndarray):
            return sample
        if isinstance(sample, int):
            return to_ndarray([sample], dtype=np.int64)
        if isinstance(sample, dict):
            return to_ndarray(sample)
        raise TypeError(
            '`random_action` should be either int/np.ndarray or dict of int/np.ndarray, but get {}: {}'.format(
                type(sample), sample
            )
        )


class FakeModel(object):
    """Fake model that answers for either all envs or a random leading subset of them."""

    def forward(self, obs):
        """
        Map env keys of ``obs`` to empty lists.

        With probability about one half every key of ``obs`` is kept; otherwise only a randomly
        sized prefix of the keys (in ``obs`` iteration order) is kept.
        """
        if random.random() > 0.5:
            selected = obs
        else:
            upper = len(obs) + 1
            # randint is inclusive, so ``count`` may exceed len(obs); the slice simply caps it.
            count = random.randint(1, upper)
            selected = list(obs.keys())[:count]
        return {key: [] for key in selected}


@pytest.fixture(scope='class')
def setup_model_type():
    """Class-scoped fixture returning the ``FakeModel`` class itself (not an instance)."""
    model_type = FakeModel
    return model_type


def get_base_manager_cfg(env_num=3):
    """
    Build the manager config used by the base env manager fixtures.

    Arguments:
        - env_num: Number of per-env entries placed under ``'env_cfg'``.
    Returns:
        - EasyDict holding ``env_cfg``, ``episode_num``, ``reset_timeout``, ``step_timeout`` \
            and ``max_retry``.
    """
    env_cfg = []
    for idx in range(env_num):
        env_cfg.append({'name': 'name{}'.format(idx), 'scale': 1.0})
    cfg = {
        'env_cfg': env_cfg,
        'episode_num': 2,
        'reset_timeout': 10,
        'step_timeout': 8,
        'max_retry': 5,
    }
    return EasyDict(cfg)


def get_subprecess_manager_cfg(env_num=3):
    """
    Build the manager config used by the subprocess env manager fixtures.

    Arguments:
        - env_num: Number of per-env entries placed under ``'env_cfg'``.
    Returns:
        - EasyDict holding ``env_cfg``, ``episode_num``, ``connect_timeout``, ``step_timeout`` \
            and ``max_retry``.
    """
    env_cfg = []
    for idx in range(env_num):
        env_cfg.append({'name': 'name{}'.format(idx), 'scale': 1.0})
    cfg = {
        'env_cfg': env_cfg,
        'episode_num': 2,
        'connect_timeout': 8,
        'step_timeout': 5,
        'max_retry': 2,
    }
    return EasyDict(cfg)


def get_gym_vector_manager_cfg(env_num=3):
    """
    Build a manager config intended for a gym-vector style env manager.

    Arguments:
        - env_num: Number of per-env entries placed under ``'env_cfg'``.
    Returns:
        - EasyDict holding ``env_cfg``, ``episode_num``, ``connect_timeout``, ``step_timeout``, \
            ``max_retry`` and ``share_memory``.
    """
    # NOTE(review): these entries carry no 'scale' key, which FakeEnv.__init__ reads - confirm
    # that they are never used to construct a FakeEnv.
    env_cfg = []
    for idx in range(env_num):
        env_cfg.append({'name': 'name{}'.format(idx)})
    cfg = {
        'env_cfg': env_cfg,
        'episode_num': 2,
        'connect_timeout': 8,
        'step_timeout': 5,
        'max_retry': 2,
        # NOTE(review): other fixtures in this file set 'shared_memory'; verify that the
        # 'share_memory' spelling here is the key the consumer expects.
        'share_memory': True,
    }
    return EasyDict(cfg)


@pytest.fixture(scope='function')
def setup_base_manager_cfg():
    """Config for ``BaseEnvManager``: its defaults merged with three ``FakeEnv`` factories."""
    overrides = get_base_manager_cfg(3)
    # Replace the per-env config list with zero-argument env constructors.
    per_env_cfg = overrides.pop('env_cfg')
    overrides['env_fn'] = [partial(FakeEnv, cfg=single_cfg) for single_cfg in per_env_cfg]
    defaults = BaseEnvManager.default_config()
    return deep_merge_dicts(defaults, EasyDict(overrides))


@pytest.fixture(scope='function')
def setup_fast_base_manager_cfg():
    """Like the base manager config, but every env uses ``scale`` 0.1 so simulated durations shrink."""
    overrides = get_base_manager_cfg(3)
    per_env_cfg = overrides.pop('env_cfg')
    env_fns = []
    for single_cfg in per_env_cfg:
        # Lower the scale before binding the config into the env constructor.
        single_cfg['scale'] = 0.1
    for single_cfg in per_env_cfg:
        env_fns.append(partial(FakeEnv, cfg=single_cfg))
    overrides['env_fn'] = env_fns
    defaults = BaseEnvManager.default_config()
    return deep_merge_dicts(defaults, EasyDict(overrides))


@pytest.fixture(scope='function')
def setup_sync_manager_cfg():
    """Config for ``SyncSubprocessEnvManager``: its defaults merged with three ``FakeEnv`` factories."""
    overrides = get_subprecess_manager_cfg(3)
    per_env_cfg = overrides.pop('env_cfg')
    # TODO(nyz) test fail when shared_memory = True
    overrides['shared_memory'] = False
    overrides['env_fn'] = [partial(FakeEnv, cfg=single_cfg) for single_cfg in per_env_cfg]
    defaults = SyncSubprocessEnvManager.default_config()
    return deep_merge_dicts(defaults, EasyDict(overrides))


@pytest.fixture(scope='function')
def setup_async_manager_cfg():
    """Config for ``AsyncSubprocessEnvManager``: its defaults merged with three ``FakeAsyncEnv`` factories."""
    overrides = get_subprecess_manager_cfg(3)
    per_env_cfg = overrides.pop('env_cfg')
    overrides['env_fn'] = [partial(FakeAsyncEnv, cfg=single_cfg) for single_cfg in per_env_cfg]
    overrides['shared_memory'] = False
    defaults = AsyncSubprocessEnvManager.default_config()
    return deep_merge_dicts(defaults, EasyDict(overrides))


@pytest.fixture(scope='function')
def setup_gym_vector_manager_cfg():
    """Config with three ``FakeGymEnv`` factories; no manager defaults are merged in."""
    overrides = get_subprecess_manager_cfg(3)
    per_env_cfg = overrides.pop('env_cfg')
    overrides['env_fn'] = [partial(FakeGymEnv, cfg=single_cfg) for single_cfg in per_env_cfg]
    overrides['shared_memory'] = False
    return EasyDict(overrides)