baiyanlali-zhao committed on
Commit 8be1cb6 · 1 Parent(s): c3a57cc

add src.env

Files changed (5)
  1. .gitignore +3 -3
  2. src/env/environments.py +399 -0
  3. src/env/logger.py +182 -0
  4. src/env/rfunc.py +219 -0
  5. src/env/rfuncs.py +9 -0
.gitignore CHANGED
@@ -125,11 +125,11 @@ celerybeat.pid
 *.sage.py
 
 # Environments
-.env
+# .env
 .venv
-env/
+# env/
 venv/
-ENV/
+# ENV/
 env.bak/
 venv.bak/
 
src/env/environments.py ADDED
@@ -0,0 +1,399 @@
import random
import time

import gym
import numpy as np
from typing import Tuple, List, Dict, Callable, Optional

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.base_vec_env import VecEnvStepReturn, VecEnvObs
from stable_baselines3.common.vec_env.subproc_vec_env import _flatten_obs

from src.env.logger import InfoCollector
from src.env.rfunc import RewardFunc
from src.smb.asyncsimlt import AsycSimltPool
from src.smb.level import lvlhcat
from src.gan.gans import SAGenerator
from src.gan.gankits import *
from src.smb.proxy import MarioProxy, MarioJavaAgents
from src.utils.datastruct import RingQueue


def get_padded_obs(vecs, histlen, add_batch_dim=False):
    if len(vecs) < histlen:
        lack = histlen - len(vecs)
        pad = [np.zeros([nz], np.float32) for _ in range(lack)]
        obs = np.concatenate(pad + vecs)
    else:
        obs = np.concatenate(vecs)
    if add_batch_dim:
        obs = np.reshape(obs, [1, -1])
    return obs


class SingleProcessOLGenEnv(gym.Env):
    def __init__(self, rfunc, decoder: SAGenerator, eplen: int = 50, device='cuda:0'):
        self.rfunc = rfunc
        self.initvec_set = np.load(getpath('smb/init_latvecs.npy'))
        self.decoder = decoder
        self.decoder.to(device)
        self.hist_len = self.rfunc.get_n()
        self.eplen = eplen
        self.device = device
        self.action_space = gym.spaces.Box(-1, 1, (nz,))
        self.observation_space = gym.spaces.Box(-1, 1, (self.hist_len * nz,))
        # self.obs_queue = RingQueue(self.hist_len)
        self.lat_vecs = []
        self.simulator = MarioProxy()

    def step(self, action):
        self.lat_vecs.append(action)
        done = len(self.lat_vecs) == (self.eplen + 1)
        info = {}
        if done:
            rewss = self.__evalute()
            info['rewss'] = rewss
            rewsums = [sum(items) for items in zip(*rewss.values())]
            info['transitions'] = self.__process_traj(rewsums[-self.eplen:])
            self.reset()
        return self.getobs(), 0, done, info

    def __evalute(self):
        z = torch.tensor(np.stack(self.lat_vecs).reshape([-1, nz, 1, 1]), device=self.device, dtype=torch.float)
        segs = process_onehot(self.decoder(z))
        lvl = lvlhcat(segs)
        simlt_res = MarioProxy.get_seg_infos(self.simulator.simulate_complete(lvl))
        rewardss = self.rfunc.get_rewards(segs=segs, simlt_res=simlt_res)
        return rewardss

    def __process_traj(self, rewards):
        obs = []
        for i in range(1, len(self.lat_vecs) + 1):
            ob = get_padded_obs(self.lat_vecs[max(0, i - self.hist_len): i], self.hist_len)
            obs.append(ob)
        traj = [(obs[i], self.lat_vecs[i+1], rewards[i], obs[i+1]) for i in range(len(self.lat_vecs) - 1)]
        return traj

    def reset(self):
        self.lat_vecs.clear()
        z0 = self.initvec_set[random.randrange(0, len(self.initvec_set))]
        self.lat_vecs.append(z0)
        return self.getobs()

    def getobs(self):
        s = max(0, len(self.lat_vecs) - self.hist_len)
        return get_padded_obs(self.lat_vecs[s:], self.hist_len, True)

    def render(self, mode="human"):
        pass

    def generate_levels(self, agent, n=1, max_parallel=None):
        if max_parallel is None:
            max_parallel = min(n, 512)
        levels = []
        latvecs = []
        obs_queues = [RingQueue(self.hist_len) for _ in range(max_parallel)]
        while len(levels) < n:
            veclists = [[] for _ in range(min(max_parallel, n - len(levels)))]
            for queue, veclist in zip(obs_queues, veclists):
                queue.clear()
                init_latvec = self.initvec_set[random.randrange(0, len(self.initvec_set))]
                queue.push(init_latvec)
                veclist.append(init_latvec)
            for _ in range(self.eplen):
                obs = np.stack([get_padded_obs(queue.to_list(), self.hist_len) for queue in obs_queues])
                actions = agent.make_decision(obs)
                for queue, veclist, action in zip(obs_queues, veclists, actions):
                    queue.push(action)
                    veclist.append(action)
            for veclist in veclists:
                latvecs.append(np.stack(veclist))
                z = torch.tensor(latvecs[-1], device=self.device).view(-1, nz, 1, 1)
                lvl = lvlhcat(process_onehot(self.decoder(z)))
                levels.append(lvl)
        return levels, latvecs


class AsyncOlGenEnv:
    def __init__(self, histlen, decoder: SAGenerator, eval_pool: AsycSimltPool, eplen: int = 50, device='cuda:0'):
        self.initvec_set = np.load(getpath('smb/init_latvecs.npy'))
        self.decoder = decoder
        self.decoder.to(device)
        self.device = device
        self.eval_pool = eval_pool
        self.eplen = eplen
        self.tid = 0
        self.histlen = histlen

        self.cur_vectraj = []
        self.buffer = {}

    def reset(self):
        if len(self.cur_vectraj) > 0:
            self.buffer[self.tid] = self.cur_vectraj
            self.cur_vectraj = []
            self.tid += 1
        z0 = self.initvec_set[random.randrange(0, len(self.initvec_set))]
        self.cur_vectraj.append(z0)
        return self.getobs()

    def step(self, action):
        self.cur_vectraj.append(action)
        done = len(self.cur_vectraj) == (self.eplen + 1)
        if done:
            self.__submit_eval_task()
            self.reset()
        return self.getobs(), done

    def getobs(self):
        s = max(0, len(self.cur_vectraj) - self.histlen)
        return get_padded_obs(self.cur_vectraj[s:], self.histlen, True)

    def __submit_eval_task(self):
        z = torch.tensor(np.stack(self.cur_vectraj).reshape([-1, nz, 1, 1]), device=self.device)
        segs = process_onehot(self.decoder(torch.clamp(z, -1, 1)))
        lvl = lvlhcat(segs)
        args = (self.tid, str(lvl))
        self.eval_pool.put('evaluate', args)

    def refresh(self):
        if self.eval_pool is not None:
            self.eval_pool.refresh()

    def rollout(self, close=False, wait=False) -> Tuple[List[Tuple], List[Dict[str, List]]]:
        transitions, rewss = [], []
        if close:
            eval_res = self.eval_pool.close()
        else:
            eval_res = self.eval_pool.get(wait)
        for tid, rewards in eval_res:
            rewss.append(rewards)
            rewsums = [sum(items) for items in zip(*rewards.values())]
            vectraj = self.buffer.pop(tid)
            transitions += self.__process_traj(vectraj, rewsums[-self.eplen:])
        return transitions, rewss

    def __process_traj(self, vectraj, rewards):
        obs = []
        for i in range(1, len(vectraj) + 1):
            ob = get_padded_obs(vectraj[max(0, i - self.histlen): i], self.histlen)
            obs.append(ob)
        traj = [(obs[i], vectraj[i+1], rewards[i], obs[i+1]) for i in range(len(vectraj) - 1)]
        return traj

    def close(self):
        res = self.rollout(True)
        self.eval_pool = None
        return res

    def generate_levels(self, agent, n=1, max_parallel=None):
        if max_parallel is None:
            max_parallel = min(n, 512)
        levels = []
        latvecs = []
        obs_queues = [RingQueue(self.histlen) for _ in range(max_parallel)]
        while len(levels) < n:
            veclists = [[] for _ in range(min(max_parallel, n - len(levels)))]
            for queue, veclist in zip(obs_queues, veclists):
                queue.clear()
                init_latvec = self.initvec_set[random.randrange(0, len(self.initvec_set))]
                queue.push(init_latvec)
                veclist.append(init_latvec)
            for _ in range(self.eplen):
                obs = np.stack([get_padded_obs(queue.to_list(), self.histlen) for queue in obs_queues])
                actions = agent.make_decision(obs)
                for queue, veclist, action in zip(obs_queues, veclists, actions):
                    queue.push(action)
                    veclist.append(action)
            for veclist in veclists:
                latvecs.append(np.stack(veclist))
                z = torch.tensor(latvecs[-1], device=self.device).view(-1, nz, 1, 1)
                lvl = lvlhcat(process_onehot(self.decoder(z)))
                levels.append(lvl)
        return levels, latvecs


######### Adopted from https://github.com/SUSTechGameAI/MFEDRL #########
class SyncOLGenWorkerEnv(gym.Env):
    def __init__(self, rfunc=None, hist_len=5, eplen=25, return_lvl=False, init_one=False, play_style='Runner'):
        self.rfunc = RewardFunc() if rfunc is None else rfunc
        self.mario_proxy = MarioProxy() if self.rfunc.require_simlt else None
        self.action_space = gym.spaces.Box(-1, 1, (nz,))
        self.hist_len = hist_len
        self.observation_space = gym.spaces.Box(-1, 1, (hist_len * nz,))
        self.segs = []
        self.latvec_archive = RingQueue(hist_len)
        self.eplen = eplen
        self.counter = 0
        # self.repairer = DivideConquerRepairer()
        self.init_one = init_one
        self.backup_latvecs = None
        self.backup_strsegs = None
        self.return_lvl = return_lvl
        self.jagent = MarioJavaAgents.__getitem__(play_style)
        self.simlt_k = 80 if play_style == 'Runner' else 320

    def receive(self, **kwargs):
        for key in kwargs.keys():
            setattr(self, key, kwargs[key])

    def step(self, data):
        action, strseg = data
        seg = MarioLevel(strseg)
        self.latvec_archive.push(action)

        self.counter += 1
        self.segs.append(seg)
        done = self.counter >= self.eplen
        if done:
            full_level = lvlhcat(self.segs)
            # full_level = self.repairer.repair(full_level)
            w = MarioLevel.seg_width
            segs = [full_level[:, s: s + w] for s in range(0, full_level.w, w)]
            if self.mario_proxy:
                raw_simlt_res = self.mario_proxy.simulate_complete(lvlhcat(segs), self.jagent, self.simlt_k)
                simlt_res = MarioProxy.get_seg_infos(raw_simlt_res)
            else:
                simlt_res = None
            rewards = self.rfunc.get_rewards(segs=segs, simlt_res=simlt_res)
            info = {}
            total_score = 0
            if self.return_lvl:
                info['LevelStr'] = str(full_level)
            for key in rewards:
                info[f'{key}_reward_list'] = rewards[key][-self.eplen:]
                info[f'{key}'] = sum(rewards[key][-self.eplen:])
                total_score += info[f'{key}']
            info['TotalScore'] = total_score
            info['EpLength'] = self.counter
        else:
            info = {}
        return self.__get_obs(), 0, done, info

    def reset(self):
        self.segs.clear()
        self.latvec_archive.clear()
        for latvec, strseg in zip(self.backup_latvecs, self.backup_strsegs):
            self.latvec_archive.push(latvec)
            self.segs.append(MarioLevel(strseg))

        self.backup_latvecs, self.backup_strsegs = None, None
        self.counter = 0
        return self.__get_obs()

    def __get_obs(self):
        lack = self.hist_len - len(self.latvec_archive)
        pad = [np.zeros([nz], np.float32) for _ in range(lack)]
        return np.concatenate([*pad, *self.latvec_archive.to_list()])

    def render(self, mode='human'):
        pass


class VecOLGenEnv(SubprocVecEnv):
    def __init__(
        self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[str] = None, hist_len=5, eplen=50,
        init_one=True, log_path=None, log_itv=-1, log_targets=None, device='cuda:0'
    ):
        super(VecOLGenEnv, self).__init__(env_fns, start_method)
        self.decoder = get_decoder(device=device)

        if log_path:
            self.logger = InfoCollector(log_path, log_itv, log_targets)
        else:
            self.logger = None
        self.hist_len = hist_len
        self.total_steps = 0
        self.start_time = time.time()
        self.eplen = eplen
        self.device = device
        self.init_one = init_one
        self.latvec_set = np.load(getpath('smb/init_latvecs.npy'))

    def step_async(self, actions: np.ndarray) -> None:
        with torch.no_grad():
            z = torch.tensor(actions.astype(np.float32), device=self.device).view(-1, nz, 1, 1)
            segs = process_onehot(self.decoder(z))
        for remote, action, seg in zip(self.remotes, actions, segs):
            remote.send(("step", (action, str(seg))))
        self.waiting = True

    def step_wait(self) -> VecEnvStepReturn:
        self.total_steps += self.num_envs
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)

        envs_to_send = [i for i in range(self.num_envs) if dones[i]]
        self.send_reset_data(envs_to_send)

        if self.logger is not None:
            for i in range(self.num_envs):
                if infos[i]:
                    infos[i]['TotalSteps'] = self.total_steps
                    infos[i]['TimePassed'] = time.time() - self.start_time
            self.logger.on_step(dones, infos)
        return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos

    def reset(self) -> VecEnvObs:
        self.send_reset_data()
        for remote in self.remotes:
            remote.send(("reset", None))
        obs = [remote.recv() for remote in self.remotes]
        self.send_reset_data()
        return _flatten_obs(obs, self.observation_space)

    def send_reset_data(self, env_ids=None):
        if env_ids is None:
            env_ids = [*range(self.num_envs)]
        target_remotes = self._get_target_remotes(env_ids)

        n_inits = 1 if self.init_one else self.hist_len
        # latvecs = [sample_latvec(n_inits, tensor=False) for _ in range(len(env_ids))]

        latvecs = [self.latvec_set[random.sample(range(len(self.latvec_set)), n_inits)] for _ in range(len(env_ids))]
        with torch.no_grad():
            segss = [[] for _ in range(len(env_ids))]
            for i in range(len(env_ids)):
                z = torch.tensor(latvecs[i]).view(-1, nz, 1, 1).to(self.device)
                segss[i] = [process_onehot(self.decoder(z))] if self.init_one else process_onehot(self.decoder(z))
        for remote, latvec, segs in zip(target_remotes, latvecs, segss):
            kwargs = {'backup_latvecs': latvec, 'backup_strsegs': [str(seg) for seg in segs]}
            remote.send(("env_method", ('receive', [], kwargs)))
        for remote in target_remotes:
            remote.recv()

    def close(self) -> None:
        super().close()
        if self.logger is not None:
            self.logger.close()


def make_vec_offrew_env(
    num_envs, rfunc=None, log_path=None, eplen=25, log_itv=-1, hist_len=5, init_one=True,
    play_style='Runner', device='cuda:0', log_targets=None, return_lvl=False
):
    return make_vec_env(
        SyncOLGenWorkerEnv, n_envs=num_envs, vec_env_cls=VecOLGenEnv,
        vec_env_kwargs={
            'log_path': log_path,
            'log_itv': log_itv,
            'log_targets': log_targets,
            'device': device,
            'eplen': eplen,
            'hist_len': hist_len,
            'init_one': init_one
        },
        env_kwargs={
            'rfunc': rfunc,
            'eplen': eplen,
            'return_lvl': return_lvl,
            'play_style': play_style,
            'hist_len': hist_len,
            'init_one': init_one
        }
    )
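
Usage sketch (not part of this commit): wiring the synchronous pipeline above together. It assumes the repository is importable as `src`, and that `smb/init_latvecs.npy` plus the GAN decoder weights are available; the random-action rollout is only there to show the VecEnv step contract, not a training setup.

import numpy as np
from src.env import rfuncs
from src.env.environments import make_vec_offrew_env

if __name__ == '__main__':  # SubprocVecEnv spawns worker processes, so guard the entry point
    venv = make_vec_offrew_env(
        num_envs=4,
        rfunc=rfuncs.default(),   # LevelSACN + GameplaySACN + Playability
        log_path=None,            # pass a directory to enable InfoCollector logging
        eplen=25, hist_len=5, device='cuda:0',
    )
    obs = venv.reset()
    for _ in range(10):
        # latent-vector actions in [-1, 1]; a trained policy would go here
        actions = np.random.uniform(-1, 1, (venv.num_envs, *venv.action_space.shape)).astype(np.float32)
        obs, rews, dones, infos = venv.step(actions)
    venv.close()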
src/env/logger.py ADDED
@@ -0,0 +1,182 @@
import json
import os
import csv
import numpy as np
from itertools import product
from src.smb.level import save_batch
from src.utils.filesys import getpath


############### Loggers for async environment ###############
class AsyncCsvLogger:
    def __init__(self, target, rfunc, buffer_size=50):
        self.rterms = tuple(term.get_name() for term in rfunc.terms)
        self.cols = ('steps', *self.rterms, 'reward_sum', 'time', 'trans', 'updates', '')
        self.buffer = []
        self.buffer_size = buffer_size
        self.ftarget = open(getpath(target), 'w', newline='')
        self.writer = csv.writer(self.ftarget)
        self.writer.writerow(self.cols)

    def on_episode(self, **kwargs):
        for rews in kwargs['rewss']:
            rews_list = [sum(rews[key]) for key in self.rterms]
            self.buffer.append([
                kwargs['steps'], *rews_list, sum(rews_list),
                kwargs['time'], kwargs['trans'], kwargs['updates']
            ])
        self.__try_write()
        if kwargs['close']:
            self.close()

    def __try_write(self):
        if len(self.buffer) < self.buffer_size:
            return
        self.writer.writerows(self.buffer)
        self.ftarget.flush()
        self.buffer.clear()

    def close(self):
        self.writer.writerows(self.buffer)
        self.ftarget.close()


class AsyncStdLogger:
    def __init__(self, rfunc, itv=2000, path=''):
        self.rterms = tuple(term.get_name() for term in rfunc.terms)
        self.rews = {rterm: 0. for rterm in self.rterms}
        self.n = 0
        self.itv = itv
        self.horizon = itv
        if not len(path):
            self.f = None
        else:
            self.f = open(getpath(path), 'w')
        self.last_steps = 0
        self.last_trans = 0
        self.last_updates = 0
        self.buffer = []

    def on_episode(self, **kwargs):
        newrews = {k: self.rews[k] for k in self.rterms}
        for rews, k in product(kwargs['rewss'], self.rterms):
            newrews[k] = newrews[k] + sum(rews[k])
        self.rews = newrews
        self.n += len(kwargs['rewss'])
        if kwargs['steps'] >= self.horizon or kwargs['close']:
            self.__output(**kwargs)
            self.horizon += self.itv
            self.rews = {k: 0 for k in self.rews.keys()}
            self.n = 0
            self.last_steps = kwargs['steps']
            self.last_trans = kwargs['trans']
            self.last_updates = kwargs['updates']
        if kwargs['close'] and self.f is not None:
            self.f.close()

    def __output(self, **kwargs):
        steps = kwargs['steps']
        if kwargs['close']:
            head = '-' * 20 + 'Closing rollouts' + '-' * 20
        else:
            head = '-' * 20 + f'Rollout of {self.last_steps}-{steps} steps' + '-' * 20
        self.buffer.append(head)
        rsum = 0
        for t in self.rterms:
            v = 0 if self.n == 0 else self.rews[t] / self.n
            rsum += v
            self.buffer.append(f'{t}: {v:.2f}')
        self.buffer.append(f'Reward sum: {rsum: .2f}')
        self.buffer.append('Time elapsed: %.1fs' % kwargs['time'])
        self.buffer.append('Transitions collected: %d (%d in total)' % (kwargs['trans'] - self.last_trans, kwargs['trans']))
        self.buffer.append('Number of updates: %d (%d in total)' % (kwargs['updates'] - self.last_updates, kwargs['updates']))
        if self.f is None:
            print('\n'.join(self.buffer) + '\n')
        else:
            self.f.write('\n'.join(self.buffer) + '\n')
            self.f.flush()
        self.buffer.clear()

    def __reset(self):
        pass


class GenResLogger:
    def __init__(self, root_path, k, itv=5000):
        self.k = k
        self.itv = itv
        self.horizon = 0
        self.path = getpath(f'{root_path}/gen_log')
        os.makedirs(self.path, exist_ok=True)

    def on_episode(self, env, agent, steps):
        if steps >= self.horizon:
            lvls, vectraj = env.generate_levels(agent, self.k)
            # np.save(f'{self.path}/step{steps}', vectraj)
            if len(lvls):
                save_batch(lvls, f'{self.path}/step{steps}')
            self.horizon += self.itv


############### Loggers for sync environment, from https://github.com/SUSTechGameAI/MFEDRL ###############
class InfoCollector:
    ignored_keys = {'episode', 'terminal_observation'}
    save_itv = 1000

    def __init__(self, path, log_itv=100, log_targets=None):
        self.data = []
        self.path = path
        self.msg_itv = log_itv
        self.time_before_save = InfoCollector.save_itv
        self.msg_ptr = 0
        self.log_targets = [] if log_targets is None else log_targets
        if 'file' in self.log_targets:  # check the normalised list so log_targets=None is safe
            with open(f'{self.path}/log.txt', 'w') as f:
                f.write('')
        self.recent_time = 0

    def on_step(self, dones, infos):
        for done, info in zip(dones, infos):
            if done:
                self.data.append({
                    key: val for key, val in info.items()
                    if key not in InfoCollector.ignored_keys and 'reward_list' not in key
                })
        self.time_before_save -= 1
        if self.time_before_save <= 0:
            with open(f'{self.path}/ep_infos.json', 'w') as f:
                json.dump(self.data, f)
            self.time_before_save += InfoCollector.save_itv

        if self.log_targets and 0 < self.msg_itv <= (len(self.data) - self.msg_ptr):
            keys = set(self.data[-1].keys()) - {'TotalSteps', 'TimePassed', 'TotalScore', 'EpLen'}

            msg = '%sTotal steps: %d%s\n' % ('-' * 16, self.data[-1]['TotalSteps'], '-' * 16)
            msg += 'Time passed: %ds\n' % self.data[-1]['TimePassed']
            t = self.data[-1]['TimePassed'] - self.recent_time
            self.recent_time = self.data[-1]['TimePassed']
            f = sum(item['EpLength'] for item in self.data[self.msg_ptr:])
            msg += 'fps: %.3g\n' % (f / t)
            for key in keys:
                values = [item[key] for item in self.data[self.msg_ptr:]]
                values = np.array(values)
                msg += '%s: %.2f +- %.2f\n' % (key, values.mean(), values.std())
            values = [item['TotalScore'] for item in self.data[self.msg_ptr:]]
            values = np.array(values)
            msg += 'TotalScore: %.2f +- %.2f\n' % (values.mean(), values.std())

            if 'file' in self.log_targets:
                with open(f'{self.path}/log.txt', 'a') as f:
                    f.write(msg + '\n')
            if 'std' in self.log_targets:
                print(msg)
            self.msg_ptr = len(self.data)

    def close(self):
        # self.path is a directory elsewhere in this class, so write the same file as on_step
        with open(f'{self.path}/ep_infos.json', 'w') as f:
            json.dump(self.data, f)
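
Usage sketch (not part of this commit): the keyword contract the async loggers above expect from a training loop. The reward function choice and all numbers are placeholders, assuming the repository is importable.

from src.env.rfunc import RewardFunc, LevelSACN, Playability
from src.env.logger import AsyncStdLogger

rfunc = RewardFunc(LevelSACN(), Playability())
logger = AsyncStdLogger(rfunc, itv=1000)   # print a summary roughly every 1000 env steps
rewss = [   # one dict per finished episode: term name -> per-segment reward list
    {'LevelSACN': [0.8, 0.7], 'Playability': [0.0, -1.0]},
]
logger.on_episode(rewss=rewss, steps=1000, time=12.3, trans=50, updates=4, close=True)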
src/env/rfunc.py ADDED
@@ -0,0 +1,219 @@
import numpy as np
from math import ceil
from abc import abstractmethod
from src.utils.mymath import a_clip
from src.smb.level import *

defaults = {'n': 5, 'gl': 0.14, 'gg': 0.30, 'wl': 2, 'wg': 10}


class RewardFunc:
    def __init__(self, *args):
        self.terms = args
        self.require_simlt = any(term.require_simlt for term in self.terms)

    def get_rewards(self, **kwargs):
        return {
            term.get_name(): term.compute_rewards(**kwargs)
            for term in self.terms
        }

    def get_n(self):
        n = 1
        for term in self.terms:
            try:
                n = max(n, term.n)
            except AttributeError:
                pass
        return n

    def __str__(self):
        return 'Reward Function:\n' + ',\n'.join('\t' + str(term) for term in self.terms)


class RewardTerm:
    def __init__(self, require_simlt):
        self.require_simlt = require_simlt

    def get_name(self):
        return self.__class__.__name__

    @abstractmethod
    def compute_rewards(self, **kwargs):
        pass


class Playability(RewardTerm):
    def __init__(self, magnitude=1):
        super(Playability, self).__init__(True)
        self.magnitude = magnitude

    def compute_rewards(self, **kwargs):
        simlt_res = kwargs['simlt_res']
        return [0 if item['playable'] else -self.magnitude for item in simlt_res[1:]]

    def __str__(self):
        return f'{self.magnitude} * Playability'


class MeanDivergenceFun(RewardTerm):
    def __init__(self, goal_div, n=defaults['n'], s=8):
        super().__init__(False)
        self.l = goal_div * 0.26 / 0.6
        self.u = goal_div * 0.94 / 0.6
        self.n = n
        self.s = s

    def compute_rewards(self, **kwargs):
        segs = kwargs['segs']
        rewards = []
        for i in range(1, len(segs)):
            seg = segs[i]
            history = lvlhcat(segs[max(0, i - self.n): i])
            k = 0
            divergences = []
            while k * self.s <= (min(self.n, i) - 1) * MarioLevel.seg_width:
                cmp_seg = history[:, k * self.s: k * self.s + MarioLevel.seg_width]
                divergences.append(tile_pattern_js_div(seg, cmp_seg))
                k += 1
            mean_d = sum(divergences) / len(divergences)
            if mean_d < self.l:
                rewards.append(-(mean_d - self.l) ** 2)
            elif mean_d > self.u:
                rewards.append(-(mean_d - self.u) ** 2)
            else:
                rewards.append(0)
        return rewards


class SACNovelty(RewardTerm):
    def __init__(self, magnitude, goal_div, require_simlt, n):
        super().__init__(require_simlt)
        self.g = goal_div
        self.magnitude = magnitude
        self.n = n

    def compute_rewards(self, **kwargs):
        n_segs = len(kwargs['segs'])
        rewards = []
        for i in range(1, n_segs):
            reward = 0
            r_sum = 0
            for k in range(1, self.n + 1):
                j = i - k
                if j < 0:
                    break
                r = 1 - k / (self.n + 1)
                r_sum += r
                reward += a_clip(self.disim(i, j, **kwargs), self.g, r)
            rewards.append(reward * self.magnitude / r_sum)
        return rewards

    @abstractmethod
    def disim(self, i, j, **kwargs):
        pass


class LevelSACN(SACNovelty):
    def __init__(self, magnitude=1, g=defaults['gl'], w=defaults['wl'], n=defaults['n']):
        super(LevelSACN, self).__init__(magnitude, g, False, n)
        self.w = w

    def disim(self, i, j, **kwargs):
        segs = kwargs['segs']
        seg1, seg2 = segs[i], segs[j]
        return tile_pattern_js_div(seg1, seg2, self.w)

    def __str__(self):
        s = f'{self.magnitude} * LevelSACN(g={self.g:.3g}, w={self.w}, n={self.n})'
        return s


class GameplaySACN(SACNovelty):
    def __init__(self, magnitude=1, g=defaults['gg'], w=defaults['wg'], n=defaults['n']):
        super(GameplaySACN, self).__init__(magnitude, g, True, n)
        self.w = w

    def disim(self, i, j, **kwargs):
        simlt_res = kwargs['simlt_res']
        trace1, trace2 = simlt_res[i]['trace'], simlt_res[j]['trace']
        return trace_div(trace1, trace2, self.w)

    def __str__(self):
        s = f'{self.magnitude} * GameplaySACN(g={self.g:.3g}, w={self.w}, n={self.n})'
        return s


class Fun(RewardTerm):
    def __init__(self, magnitude=1., num_windows=3, lb=0.26, ub=0.94, stride=8):
        super().__init__(False)
        self.lb, self.ub = lb, ub
        self.magnitude = magnitude
        self.stride = stride
        self.num_windows = num_windows
        self.n = ceil(num_windows * stride / MarioLevel.seg_width - 1e-8)

    def compute_rewards(self, **kwargs):
        n_segs = len(kwargs['segs'])
        lvl = lvlhcat(kwargs['segs'])
        W = MarioLevel.seg_width
        rewards = []
        for i in range(1, n_segs):
            seg = lvl[:, W*i: W*(i+1)]
            divs = []
            for k in range(0, self.num_windows + 1):
                s = W * i - k * self.stride
                if s < 0:
                    break
                cmp_seg = lvl[:, s: s+W]
                divs.append(tile_pattern_kl_div(seg, cmp_seg))
            mean_div = np.mean(divs)
            rew = 0
            if mean_div > self.ub:
                rew = -(self.ub - mean_div) ** 2
            if mean_div < self.lb:
                rew = -(self.lb - mean_div) ** 2
            rewards.append(rew * self.magnitude)
        return rewards

    def __str__(self):
        s = f'{self.magnitude} * Fun(lb={self.lb:.2f}, ub={self.ub:.2f}, n={self.num_windows}, stride={self.stride})'
        return s


class HistoricalDeviation(RewardTerm):
    def __init__(self, magnitude=1., m=3, n=10):
        super().__init__(False)
        self.magnitude = magnitude
        self.m = m
        self.n = n

    def compute_rewards(self, **kwargs):
        segs = kwargs['segs']
        n_segs = len(kwargs['segs'])
        rewards = []
        for i in range(1, n_segs):
            divs = []
            for k in range(1, self.n + 1):
                j = i - k
                if j < 0:
                    break
                divs.append(tile_pattern_kl_div(segs[i], segs[j]))
            divs.sort()
            m = min(i, self.m)
            rew = np.mean(divs[:m])
            rewards.append(rew * self.magnitude)
        return rewards

    def __str__(self):
        return f'{self.magnitude} * HistoricalDeviation(m={self.m}, n={self.n})'


if __name__ == '__main__':
    # print(type(ceil(0.2)))
    # arr = [1., 3., 2.]
    # arr.sort()
    # print(arr)
    rfunc = HistoricalDeviation()
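
Standalone illustration (not from the repository): the band-shaped penalty used by `Fun` and `MeanDivergenceFun` above gives 0 reward when the mean divergence of a segment against its recent windows lies inside [lb, ub] and a quadratic penalty outside. The divergence values below are made up; in the real terms they come from `tile_pattern_kl_div` / `tile_pattern_js_div`.

def band_penalty(mean_div, lb=0.26, ub=0.94, magnitude=1.0):
    # zero inside the band, quadratic penalty outside, scaled by magnitude
    rew = 0.0
    if mean_div > ub:
        rew = -(ub - mean_div) ** 2
    if mean_div < lb:
        rew = -(lb - mean_div) ** 2
    return rew * magnitude

for d in (0.10, 0.50, 1.20):
    print(f'mean_div={d:.2f} -> reward={band_penalty(d):+.4f}')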
src/env/rfuncs.py ADDED
@@ -0,0 +1,9 @@
from src.env.rfunc import *

default = lambda: RewardFunc(LevelSACN(), GameplaySACN(), Playability())
fp = lambda: RewardFunc(Fun(200), Playability(10))
hp = lambda: RewardFunc(HistoricalDeviation(1), Playability(1))
fhp = lambda: RewardFunc(Fun(30, num_windows=21), HistoricalDeviation(3), Playability(3))
lp = lambda: RewardFunc(LevelSACN(2.), Playability(2.))
gp = lambda: RewardFunc(GameplaySACN(2.), Playability(2.))
lgp = lambda: RewardFunc(LevelSACN(), GameplaySACN(), Playability())
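
Usage sketch (not part of this commit): each preset above is a zero-argument factory, so a fresh RewardFunc is built on every call; a preset name coming from, say, a command-line flag can be resolved with getattr.

import src.env.rfuncs as rfuncs

preset = 'fhp'                      # e.g. taken from argparse
rfunc = getattr(rfuncs, preset)()   # Fun(30, num_windows=21) + HistoricalDeviation(3) + Playability(3)
print(rfunc)                        # RewardFunc.__str__ lists the weighted terms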