Commit 8be1cb6 · baiyanlali-zhao committed · Parent: c3a57cc

add src.env

Files changed:
- .gitignore +3 -3
- src/env/environments.py +399 -0
- src/env/logger.py +182 -0
- src/env/rfunc.py +219 -0
- src/env/rfuncs.py +9 -0

.gitignore
CHANGED
@@ -125,11 +125,11 @@ celerybeat.pid
 *.sage.py

 # Environments
-.env
+# .env
 .venv
-env/
+# env/
 venv/
-ENV/
+# ENV/
 env.bak/
 venv.bak/

src/env/environments.py
ADDED
@@ -0,0 +1,399 @@
import random
import time

import gym
import numpy as np
from typing import Tuple, List, Dict, Callable, Optional

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.base_vec_env import VecEnvStepReturn, VecEnvObs
from stable_baselines3.common.vec_env.subproc_vec_env import _flatten_obs

from src.env.logger import InfoCollector
from src.env.rfunc import RewardFunc
from src.smb.asyncsimlt import AsycSimltPool
from src.smb.level import lvlhcat, MarioLevel
from src.gan.gans import SAGenerator
from src.gan.gankits import *
from src.smb.proxy import MarioProxy, MarioJavaAgents
from src.utils.datastruct import RingQueue


def get_padded_obs(vecs, histlen, add_batch_dim=False):
    # Concatenate the latest latent vectors into one observation, left-padding
    # with zeros until `histlen` vectors are available.
    if len(vecs) < histlen:
        lack = histlen - len(vecs)
        pad = [np.zeros([nz], np.float32) for _ in range(lack)]
        obs = np.concatenate(pad + vecs)
    else:
        obs = np.concatenate(vecs)
    if add_batch_dim:
        obs = np.reshape(obs, [1, -1])
    return obs

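# Illustration of the left padding above (a sketch, not called anywhere;
# `nz` is the latent dimension imported via src.gan.gankits):
def _padding_sketch():
    vecs = [np.ones([nz], np.float32)]          # only one latent vector so far
    obs = get_padded_obs(vecs, histlen=3)       # zero-padded on the left
    assert obs.shape == (3 * nz,) and not obs[:2 * nz].any()
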
class SingleProcessOLGenEnv(gym.Env):
    # Online level-generation environment that decodes, simulates and rewards
    # entirely within the current process.
    def __init__(self, rfunc, decoder: SAGenerator, eplen: int = 50, device='cuda:0'):
        self.rfunc = rfunc
        self.initvec_set = np.load(getpath('smb/init_latvecs.npy'))
        self.decoder = decoder
        self.decoder.to(device)
        self.hist_len = self.rfunc.get_n()
        self.eplen = eplen
        self.device = device
        self.action_space = gym.spaces.Box(-1, 1, (nz,))
        self.observation_space = gym.spaces.Box(-1, 1, (self.hist_len * nz,))
        # self.obs_queue = RingQueue(self.hist_len)
        self.lat_vecs = []
        self.simulator = MarioProxy()

    def step(self, action):
        self.lat_vecs.append(action)
        done = len(self.lat_vecs) == (self.eplen + 1)
        info = {}
        if done:
            # Rewards are only computed once the whole episode is available.
            rewss = self.__evaluate()
            info['rewss'] = rewss
            rewsums = [sum(items) for items in zip(*rewss.values())]
            info['transitions'] = self.__process_traj(rewsums[-self.eplen:])
            self.reset()
        return self.getobs(), 0, done, info

    def __evaluate(self):
        z = torch.tensor(np.stack(self.lat_vecs).reshape([-1, nz, 1, 1]), device=self.device, dtype=torch.float)
        segs = process_onehot(self.decoder(z))
        lvl = lvlhcat(segs)
        simlt_res = MarioProxy.get_seg_infos(self.simulator.simulate_complete(lvl))
        rewardss = self.rfunc.get_rewards(segs=segs, simlt_res=simlt_res)
        return rewardss

    def __process_traj(self, rewards):
        # Rebuild (obs, action, reward, next_obs) transitions from the stored
        # latent-vector trajectory.
        obs = []
        for i in range(1, len(self.lat_vecs) + 1):
            ob = get_padded_obs(self.lat_vecs[max(0, i - self.hist_len): i], self.hist_len)
            obs.append(ob)
        traj = [(obs[i], self.lat_vecs[i + 1], rewards[i], obs[i + 1]) for i in range(len(self.lat_vecs) - 1)]
        return traj

    def reset(self):
        self.lat_vecs.clear()
        z0 = self.initvec_set[random.randrange(0, len(self.initvec_set))]
        self.lat_vecs.append(z0)
        return self.getobs()

    def getobs(self):
        s = max(0, len(self.lat_vecs) - self.hist_len)
        return get_padded_obs(self.lat_vecs[s:], self.hist_len, True)

    def render(self, mode="human"):
        pass

    def generate_levels(self, agent, n=1, max_parallel=None):
        if max_parallel is None:
            max_parallel = min(n, 512)
        levels = []
        latvecs = []
        obs_queues = [RingQueue(self.hist_len) for _ in range(max_parallel)]
        while len(levels) < n:
            veclists = [[] for _ in range(min(max_parallel, n - len(levels)))]
            for queue, veclist in zip(obs_queues, veclists):
                queue.clear()
                init_latvec = self.initvec_set[random.randrange(0, len(self.initvec_set))]
                queue.push(init_latvec)
                veclist.append(init_latvec)
            for _ in range(self.eplen):
                obs = np.stack([get_padded_obs(queue.to_list(), self.hist_len) for queue in obs_queues])
                actions = agent.make_decision(obs)
                for queue, veclist, action in zip(obs_queues, veclists, actions):
                    queue.push(action)
                    veclist.append(action)
            for veclist in veclists:
                latvecs.append(np.stack(veclist))
                z = torch.tensor(latvecs[-1], device=self.device).view(-1, nz, 1, 1)
                lvl = lvlhcat(process_onehot(self.decoder(z)))
                levels.append(lvl)
        return levels, latvecs

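# Minimal sketch of driving SingleProcessOLGenEnv with random latent actions
# (assumes the generator weights, smb/init_latvecs.npy and the Java simulator
# are available; per-step reward is 0 and the evaluation arrives via `info`):
def _single_process_rollout_sketch(rfunc, decoder):
    env = SingleProcessOLGenEnv(rfunc, decoder, eplen=50)
    obs, done = env.reset(), False
    while not done:
        obs, _, done, info = env.step(env.action_space.sample())
    return info['rewss'], info['transitions']
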
class AsyncOlGenEnv:
    # Variant whose episode evaluation is offloaded to an asynchronous
    # simulation pool; finished episodes are collected later via rollout().
    def __init__(self, histlen, decoder: SAGenerator, eval_pool: AsycSimltPool, eplen: int = 50, device='cuda:0'):
        self.initvec_set = np.load(getpath('smb/init_latvecs.npy'))
        self.decoder = decoder
        self.decoder.to(device)
        self.device = device
        self.eval_pool = eval_pool
        self.eplen = eplen
        self.tid = 0
        self.histlen = histlen

        self.cur_vectraj = []
        self.buffer = {}

    def reset(self):
        if len(self.cur_vectraj) > 0:
            # Keep the finished trajectory until its evaluation result returns.
            self.buffer[self.tid] = self.cur_vectraj
            self.cur_vectraj = []
            self.tid += 1
        z0 = self.initvec_set[random.randrange(0, len(self.initvec_set))]
        self.cur_vectraj.append(z0)
        return self.getobs()

    def step(self, action):
        self.cur_vectraj.append(action)
        done = len(self.cur_vectraj) == (self.eplen + 1)
        if done:
            self.__submit_eval_task()
            self.reset()
        return self.getobs(), done

    def getobs(self):
        s = max(0, len(self.cur_vectraj) - self.histlen)
        return get_padded_obs(self.cur_vectraj[s:], self.histlen, True)

    def __submit_eval_task(self):
        z = torch.tensor(np.stack(self.cur_vectraj).reshape([-1, nz, 1, 1]), device=self.device)
        segs = process_onehot(self.decoder(torch.clamp(z, -1, 1)))
        lvl = lvlhcat(segs)
        args = (self.tid, str(lvl))
        self.eval_pool.put('evaluate', args)

    def refresh(self):
        if self.eval_pool is not None:
            self.eval_pool.refresh()

    def rollout(self, close=False, wait=False) -> Tuple[List[Tuple], List[Dict[str, List]]]:
        transitions, rewss = [], []
        if close:
            eval_res = self.eval_pool.close()
        else:
            eval_res = self.eval_pool.get(wait)
        for tid, rewards in eval_res:
            rewss.append(rewards)
            rewsums = [sum(items) for items in zip(*rewards.values())]
            vectraj = self.buffer.pop(tid)
            transitions += self.__process_traj(vectraj, rewsums[-self.eplen:])
        return transitions, rewss

    def __process_traj(self, vectraj, rewards):
        obs = []
        for i in range(1, len(vectraj) + 1):
            ob = get_padded_obs(vectraj[max(0, i - self.histlen): i], self.histlen)
            obs.append(ob)
        traj = [(obs[i], vectraj[i + 1], rewards[i], obs[i + 1]) for i in range(len(vectraj) - 1)]
        return traj

    def close(self):
        res = self.rollout(True)
        self.eval_pool = None
        return res

    def generate_levels(self, agent, n=1, max_parallel=None):
        if max_parallel is None:
            max_parallel = min(n, 512)
        levels = []
        latvecs = []
        obs_queues = [RingQueue(self.histlen) for _ in range(max_parallel)]
        while len(levels) < n:
            veclists = [[] for _ in range(min(max_parallel, n - len(levels)))]
            for queue, veclist in zip(obs_queues, veclists):
                queue.clear()
                init_latvec = self.initvec_set[random.randrange(0, len(self.initvec_set))]
                queue.push(init_latvec)
                veclist.append(init_latvec)
            for _ in range(self.eplen):
                obs = np.stack([get_padded_obs(queue.to_list(), self.histlen) for queue in obs_queues])
                actions = agent.make_decision(obs)
                for queue, veclist, action in zip(obs_queues, veclists, actions):
                    queue.push(action)
                    veclist.append(action)
            for veclist in veclists:
                latvecs.append(np.stack(veclist))
                z = torch.tensor(latvecs[-1], device=self.device).view(-1, nz, 1, 1)
                lvl = lvlhcat(process_onehot(self.decoder(z)))
                levels.append(lvl)
        return levels, latvecs

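# Interaction pattern for AsyncOlGenEnv (sketch): evaluation happens inside
# the AsycSimltPool, so finished episodes are harvested via rollout().
# `agent` is assumed to follow the make_decision interface used by
# generate_levels above; constructing the pool itself is omitted here.
def _async_collect_sketch(env, agent, n_steps):
    obs, transitions = env.reset(), []
    for _ in range(n_steps):
        action = agent.make_decision(obs)[0]
        obs, _ = env.step(action)
        env.refresh()                    # let the evaluation pool make progress
        batch, _ = env.rollout()         # non-blocking harvest of finished episodes
        transitions += batch
    return transitions
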
######### Adapted from https://github.com/SUSTechGameAI/MFEDRL #########
class SyncOLGenWorkerEnv(gym.Env):
    # Worker-side environment for the vectorized setup: the parent process
    # decodes latent vectors, so each step receives an (action, segment string) pair.
    def __init__(self, rfunc=None, hist_len=5, eplen=25, return_lvl=False, init_one=False, play_style='Runner'):
        self.rfunc = RewardFunc() if rfunc is None else rfunc
        self.mario_proxy = MarioProxy() if self.rfunc.require_simlt else None
        self.action_space = gym.spaces.Box(-1, 1, (nz,))
        self.hist_len = hist_len
        self.observation_space = gym.spaces.Box(-1, 1, (hist_len * nz,))
        self.segs = []
        self.latvec_archive = RingQueue(hist_len)
        self.eplen = eplen
        self.counter = 0
        # self.repairer = DivideConquerRepairer()
        self.init_one = init_one
        self.backup_latvecs = None
        self.backup_strsegs = None
        self.return_lvl = return_lvl
        self.jagent = MarioJavaAgents[play_style]
        self.simlt_k = 80 if play_style == 'Runner' else 320

    def receive(self, **kwargs):
        for key, val in kwargs.items():
            setattr(self, key, val)

    def step(self, data):
        action, strseg = data
        seg = MarioLevel(strseg)
        self.latvec_archive.push(action)

        self.counter += 1
        self.segs.append(seg)
        done = self.counter >= self.eplen
        if done:
            full_level = lvlhcat(self.segs)
            # full_level = self.repairer.repair(full_level)
            w = MarioLevel.seg_width
            segs = [full_level[:, s: s + w] for s in range(0, full_level.w, w)]
            if self.mario_proxy:
                raw_simlt_res = self.mario_proxy.simulate_complete(lvlhcat(segs), self.jagent, self.simlt_k)
                simlt_res = MarioProxy.get_seg_infos(raw_simlt_res)
            else:
                simlt_res = None
            rewards = self.rfunc.get_rewards(segs=segs, simlt_res=simlt_res)
            info = {}
            total_score = 0
            if self.return_lvl:
                info['LevelStr'] = str(full_level)
            for key in rewards:
                info[f'{key}_reward_list'] = rewards[key][-self.eplen:]
                info[f'{key}'] = sum(rewards[key][-self.eplen:])
                total_score += info[f'{key}']
            info['TotalScore'] = total_score
            info['EpLength'] = self.counter
        else:
            info = {}
        return self.__get_obs(), 0, done, info

    def reset(self):
        # backup_latvecs / backup_strsegs must be delivered via receive()
        # before every reset; see VecOLGenEnv.send_reset_data.
        self.segs.clear()
        self.latvec_archive.clear()
        for latvec, strseg in zip(self.backup_latvecs, self.backup_strsegs):
            self.latvec_archive.push(latvec)
            self.segs.append(MarioLevel(strseg))

        self.backup_latvecs, self.backup_strsegs = None, None
        self.counter = 0
        return self.__get_obs()

    def __get_obs(self):
        lack = self.hist_len - len(self.latvec_archive)
        pad = [np.zeros([nz], np.float32) for _ in range(lack)]
        return np.concatenate([*pad, *self.latvec_archive.to_list()])

    def render(self, mode='human'):
        pass

class VecOLGenEnv(SubprocVecEnv):
    # Vectorized wrapper: decodes actions to segments once in the parent
    # process, forwards them to the workers, and pre-loads reset data so the
    # workers can auto-reset when an episode ends.
    def __init__(
        self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[str] = None, hist_len=5, eplen=50,
        init_one=True, log_path=None, log_itv=-1, log_targets=None, device='cuda:0'
    ):
        super(VecOLGenEnv, self).__init__(env_fns, start_method)
        self.decoder = get_decoder(device=device)

        if log_path:
            self.logger = InfoCollector(log_path, log_itv, log_targets)
        else:
            self.logger = None
        self.hist_len = hist_len
        self.total_steps = 0
        self.start_time = time.time()
        self.eplen = eplen
        self.device = device
        self.init_one = init_one
        self.latvec_set = np.load(getpath('smb/init_latvecs.npy'))

    def step_async(self, actions: np.ndarray) -> None:
        with torch.no_grad():
            z = torch.tensor(actions.astype(np.float32), device=self.device).view(-1, nz, 1, 1)
            segs = process_onehot(self.decoder(z))
        for remote, action, seg in zip(self.remotes, actions, segs):
            remote.send(("step", (action, str(seg))))
        self.waiting = True

    def step_wait(self) -> VecEnvStepReturn:
        self.total_steps += self.num_envs
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)

        envs_to_send = [i for i in range(self.num_envs) if dones[i]]
        self.send_reset_data(envs_to_send)

        if self.logger is not None:
            for i in range(self.num_envs):
                if infos[i]:
                    infos[i]['TotalSteps'] = self.total_steps
                    infos[i]['TimePassed'] = time.time() - self.start_time
            self.logger.on_step(dones, infos)
        return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos

    def reset(self) -> VecEnvObs:
        self.send_reset_data()
        for remote in self.remotes:
            remote.send(("reset", None))
        obs = [remote.recv() for remote in self.remotes]
        # Pre-load fresh reset data so workers can auto-reset on done.
        self.send_reset_data()
        return _flatten_obs(obs, self.observation_space)

    def send_reset_data(self, env_ids=None):
        if env_ids is None:
            env_ids = [*range(self.num_envs)]
        target_remotes = self._get_target_remotes(env_ids)

        n_inits = 1 if self.init_one else self.hist_len
        # latvecs = [sample_latvec(n_inits, tensor=False) for _ in range(len(env_ids))]
        latvecs = [self.latvec_set[random.sample(range(len(self.latvec_set)), n_inits)] for _ in range(len(env_ids))]
        with torch.no_grad():
            segss = [[] for _ in range(len(env_ids))]
            for i in range(len(env_ids)):
                z = torch.tensor(latvecs[i]).view(-1, nz, 1, 1).to(self.device)
                segss[i] = [process_onehot(self.decoder(z))] if self.init_one else process_onehot(self.decoder(z))
        for remote, latvec, segs in zip(target_remotes, latvecs, segss):
            kwargs = {'backup_latvecs': latvec, 'backup_strsegs': [str(seg) for seg in segs]}
            remote.send(("env_method", ('receive', [], kwargs)))
        for remote in target_remotes:
            remote.recv()

    def close(self) -> None:
        super().close()
        if self.logger is not None:
            self.logger.close()


def make_vec_offrew_env(
    num_envs, rfunc=None, log_path=None, eplen=25, log_itv=-1, hist_len=5, init_one=True,
    play_style='Runner', device='cuda:0', log_targets=None, return_lvl=False
):
    return make_vec_env(
        SyncOLGenWorkerEnv, n_envs=num_envs, vec_env_cls=VecOLGenEnv,
        vec_env_kwargs={
            'log_path': log_path,
            'log_itv': log_itv,
            'log_targets': log_targets,
            'device': device,
            'eplen': eplen,
            'hist_len': hist_len,
            'init_one': init_one
        },
        env_kwargs={
            'rfunc': rfunc,
            'eplen': eplen,
            'return_lvl': return_lvl,
            'play_style': play_style,
            'hist_len': hist_len,
            'init_one': init_one
        }
    )

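A minimal usage sketch for the vectorized setup (assuming the src package, its assets and a CUDA device are available; rfuncs is the module added below):

    import numpy as np
    from src.env.environments import make_vec_offrew_env
    from src.env import rfuncs

    venv = make_vec_offrew_env(num_envs=2, rfunc=rfuncs.default(), eplen=25, hist_len=5)
    obs = venv.reset()                          # shape: (2, hist_len * nz)
    nz = obs.shape[1] // 5
    for _ in range(25):
        actions = np.random.uniform(-1, 1, (2, nz)).astype(np.float32)
        obs, rews, dones, infos = venv.step(actions)
    venv.close()
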
src/env/logger.py
ADDED
@@ -0,0 +1,182 @@
import json
import os
import csv
import numpy as np
from itertools import product
from src.smb.level import save_batch
from src.utils.filesys import getpath

############### Loggers for async environment ###############
class AsyncCsvLogger:
    def __init__(self, target, rfunc, buffer_size=50):
        self.rterms = tuple(term.get_name() for term in rfunc.terms)
        self.cols = ('steps', *self.rterms, 'reward_sum', 'time', 'trans', 'updates')
        self.buffer = []
        self.buffer_size = buffer_size
        self.ftarget = open(getpath(target), 'w', newline='')
        self.writer = csv.writer(self.ftarget)
        self.writer.writerow(self.cols)

    def on_episode(self, **kwargs):
        for rews in kwargs['rewss']:
            rews_list = [sum(rews[key]) for key in self.rterms]
            self.buffer.append([
                kwargs['steps'], *rews_list, sum(rews_list),
                kwargs['time'], kwargs['trans'], kwargs['updates']
            ])
        self.__try_write()
        if kwargs['close']:
            self.close()

    def __try_write(self):
        if len(self.buffer) < self.buffer_size:
            return
        self.writer.writerows(self.buffer)
        self.ftarget.flush()
        self.buffer.clear()

    def close(self):
        self.writer.writerows(self.buffer)
        self.ftarget.close()


class AsyncStdLogger:
    def __init__(self, rfunc, itv=2000, path=''):
        self.rterms = tuple(term.get_name() for term in rfunc.terms)
        self.rews = {rterm: 0. for rterm in self.rterms}
        self.n = 0
        self.itv = itv
        self.horizon = itv
        if not len(path):
            self.f = None
        else:
            self.f = open(getpath(path), 'w')
        self.last_steps = 0
        self.last_trans = 0
        self.last_updates = 0
        self.buffer = []

    def on_episode(self, **kwargs):
        newrews = {k: self.rews[k] for k in self.rterms}
        for rews, k in product(kwargs['rewss'], self.rterms):
            newrews[k] = newrews[k] + sum(rews[k])
        self.rews = newrews
        self.n += len(kwargs['rewss'])
        if kwargs['steps'] >= self.horizon or kwargs['close']:
            self.__output(**kwargs)
            self.horizon += self.itv
            self.rews = {k: 0 for k in self.rews.keys()}
            self.n = 0
            self.last_steps = kwargs['steps']
            self.last_trans = kwargs['trans']
            self.last_updates = kwargs['updates']
        if kwargs['close'] and self.f is not None:
            self.f.close()

    def __output(self, **kwargs):
        steps = kwargs['steps']
        if kwargs['close']:
            head = '-' * 20 + 'Closing rollouts' + '-' * 20
        else:
            head = '-' * 20 + f'Rollout of {self.last_steps}-{steps} steps' + '-' * 20
        self.buffer.append(head)
        rsum = 0
        for t in self.rterms:
            v = 0 if self.n == 0 else self.rews[t] / self.n
            rsum += v
            self.buffer.append(f'{t}: {v:.2f}')
        self.buffer.append(f'Reward sum: {rsum: .2f}')
        self.buffer.append('Time elapsed: %.1fs' % kwargs['time'])
        self.buffer.append('Transitions collected: %d (%d in total)' % (kwargs['trans'] - self.last_trans, kwargs['trans']))
        self.buffer.append('Number of updates: %d (%d in total)' % (kwargs['updates'] - self.last_updates, kwargs['updates']))
        if self.f is None:
            print('\n'.join(self.buffer) + '\n')
        else:
            self.f.write('\n'.join(self.buffer) + '\n')
            self.f.flush()
        self.buffer.clear()


class GenResLogger:
    def __init__(self, root_path, k, itv=5000):
        self.k = k
        self.itv = itv
        self.horizon = 0
        self.path = getpath(f'{root_path}/gen_log')
        os.makedirs(self.path, exist_ok=True)

    def on_episode(self, env, agent, steps):
        if steps >= self.horizon:
            lvls, vectraj = env.generate_levels(agent, self.k)
            # np.save(f'{self.path}/step{steps}', vectraj)
            if len(lvls):
                save_batch(lvls, f'{self.path}/step{steps}')
            self.horizon += self.itv

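# Sketch of the kwargs contract shared by the async loggers' on_episode
# (keys as consumed above; the reward numbers are hypothetical):
def _async_logger_sketch(rfunc):
    logger = AsyncStdLogger(rfunc, itv=2000)
    rewss = [{term.get_name(): [0.1] * 25 for term in rfunc.terms}]
    logger.on_episode(rewss=rewss, steps=2000, time=12.5,
                      trans=1250, updates=40, close=True)
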
############### Loggers for sync environment, from https://github.com/SUSTechGameAI/MFEDRL ###############
class InfoCollector:
    ignored_keys = {'episode', 'terminal_observation'}
    save_itv = 1000

    def __init__(self, path, log_itv=100, log_targets=None):
        self.data = []
        self.path = path
        self.msg_itv = log_itv
        self.time_before_save = InfoCollector.save_itv
        self.msg_ptr = 0
        self.log_targets = [] if log_targets is None else log_targets
        if 'file' in self.log_targets:
            with open(f'{self.path}/log.txt', 'w') as f:
                f.write('')
        self.recent_time = 0

    def on_step(self, dones, infos):
        for done, info in zip(dones, infos):
            if done:
                self.data.append({
                    key: val for key, val in info.items()
                    if key not in InfoCollector.ignored_keys and 'reward_list' not in key
                })
                self.time_before_save -= 1
        if self.time_before_save <= 0:
            with open(f'{self.path}/ep_infos.json', 'w') as f:
                json.dump(self.data, f)
            self.time_before_save += InfoCollector.save_itv

        if self.log_targets and 0 < self.msg_itv <= (len(self.data) - self.msg_ptr):
            keys = set(self.data[-1].keys()) - {'TotalSteps', 'TimePassed', 'TotalScore', 'EpLength'}

            msg = '%sTotal steps: %d%s\n' % ('-' * 16, self.data[-1]['TotalSteps'], '-' * 16)
            msg += 'Time passed: %ds\n' % self.data[-1]['TimePassed']
            t = self.data[-1]['TimePassed'] - self.recent_time
            self.recent_time = self.data[-1]['TimePassed']
            f = sum(item['EpLength'] for item in self.data[self.msg_ptr:])
            msg += 'fps: %.3g\n' % (f / t)
            for key in keys:
                values = [item[key] for item in self.data[self.msg_ptr:]]
                values = np.array(values)
                msg += '%s: %.2f +- %.2f\n' % (key, values.mean(), values.std())
            values = [item['TotalScore'] for item in self.data[self.msg_ptr:]]
            values = np.array(values)
            msg += 'TotalScore: %.2f +- %.2f\n' % (values.mean(), values.std())

            if 'file' in self.log_targets:
                with open(f'{self.path}/log.txt', 'a') as f:
                    f.write(msg + '\n')
            if 'std' in self.log_targets:
                print(msg)
            self.msg_ptr = len(self.data)

    def close(self):
        with open(f'{self.path}/ep_infos.json', 'w') as f:
            json.dump(self.data, f)

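A hypothetical smoke test for InfoCollector (the info keys mirror those emitted by SyncOLGenWorkerEnv.step; the target directory is assumed to exist):

    from src.env.logger import InfoCollector

    collector = InfoCollector('training_data', log_itv=1, log_targets=['std'])
    info = {'TotalSteps': 50, 'TimePassed': 3.2, 'EpLength': 25,
            'TotalScore': 1.7, 'Playability': 1.7}
    collector.on_step([True], [info])   # one finished episode: stored, then summarised
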
src/env/rfunc.py
ADDED
@@ -0,0 +1,219 @@
import numpy as np
from math import ceil
from abc import abstractmethod
from src.utils.mymath import a_clip
from src.smb.level import *

defaults = {'n': 5, 'gl': 0.14, 'gg': 0.30, 'wl': 2, 'wg': 10}


class RewardFunc:
    # A weighted collection of reward terms; returns one reward list per term.
    def __init__(self, *args):
        self.terms = args
        self.require_simlt = any(term.require_simlt for term in self.terms)

    def get_rewards(self, **kwargs):
        return {
            term.get_name(): term.compute_rewards(**kwargs)
            for term in self.terms
        }

    def get_n(self):
        # History length required by the most demanding term.
        n = 1
        for term in self.terms:
            try:
                n = max(n, term.n)
            except AttributeError:
                pass
        return n

    def __str__(self):
        return 'Reward Function:\n' + ',\n'.join('\t' + str(term) for term in self.terms)


class RewardTerm:
    def __init__(self, require_simlt):
        self.require_simlt = require_simlt

    def get_name(self):
        return self.__class__.__name__

    @abstractmethod
    def compute_rewards(self, **kwargs):
        pass


class Playability(RewardTerm):
    def __init__(self, magnitude=1):
        super(Playability, self).__init__(True)
        self.magnitude = magnitude

    def compute_rewards(self, **kwargs):
        simlt_res = kwargs['simlt_res']
        return [0 if item['playable'] else -self.magnitude for item in simlt_res[1:]]

    def __str__(self):
        return f'{self.magnitude} * Playability'


class MeanDivergenceFun(RewardTerm):
    def __init__(self, goal_div, n=defaults['n'], s=8):
        super().__init__(False)
        self.l = goal_div * 0.26 / 0.6
        self.u = goal_div * 0.94 / 0.6
        self.n = n
        self.s = s

    def compute_rewards(self, **kwargs):
        segs = kwargs['segs']
        rewards = []
        for i in range(1, len(segs)):
            seg = segs[i]
            history = lvlhcat(segs[max(0, i - self.n): i])
            k = 0
            divergences = []
            while k * self.s <= (min(self.n, i) - 1) * MarioLevel.seg_width:
                cmp_seg = history[:, k * self.s: k * self.s + MarioLevel.seg_width]
                divergences.append(tile_pattern_js_div(seg, cmp_seg))
                k += 1
            mean_d = sum(divergences) / len(divergences)
            if mean_d < self.l:
                rewards.append(-(mean_d - self.l) ** 2)
            elif mean_d > self.u:
                rewards.append(-(mean_d - self.u) ** 2)
            else:
                rewards.append(0)
        return rewards

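# Shape of RewardFunc.get_rewards output (sketch): one list of per-segment
# rewards per term, keyed by class name. Playability reads only simlt_res,
# so hypothetical results suffice:
def _playability_sketch():
    rfunc = RewardFunc(Playability(10))
    simlt_res = [{'playable': True}, {'playable': True}, {'playable': False}]
    return rfunc.get_rewards(segs=None, simlt_res=simlt_res)   # {'Playability': [0, -10]}
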
class SACNovelty(RewardTerm):
    # Novelty with decayed weights over the n most recent segments; subclasses
    # define the dissimilarity measure.
    def __init__(self, magnitude, goal_div, require_simlt, n):
        super().__init__(require_simlt)
        self.g = goal_div
        self.magnitude = magnitude
        self.n = n

    def compute_rewards(self, **kwargs):
        n_segs = len(kwargs['segs'])
        rewards = []
        for i in range(1, n_segs):
            reward = 0
            r_sum = 0
            for k in range(1, self.n + 1):
                j = i - k
                if j < 0:
                    break
                r = 1 - k / (self.n + 1)
                r_sum += r
                reward += a_clip(self.disim(i, j, **kwargs), self.g, r)
            rewards.append(reward * self.magnitude / r_sum)
        return rewards

    @abstractmethod
    def disim(self, i, j, **kwargs):
        pass


class LevelSACN(SACNovelty):
    def __init__(self, magnitude=1, g=defaults['gl'], w=defaults['wl'], n=defaults['n']):
        super(LevelSACN, self).__init__(magnitude, g, False, n)
        self.w = w

    def disim(self, i, j, **kwargs):
        segs = kwargs['segs']
        seg1, seg2 = segs[i], segs[j]
        return tile_pattern_js_div(seg1, seg2, self.w)

    def __str__(self):
        return f'{self.magnitude} * LevelSACN(g={self.g:.3g}, w={self.w}, n={self.n})'


class GameplaySACN(SACNovelty):
    def __init__(self, magnitude=1, g=defaults['gg'], w=defaults['wg'], n=defaults['n']):
        super(GameplaySACN, self).__init__(magnitude, g, True, n)
        self.w = w

    def disim(self, i, j, **kwargs):
        simlt_res = kwargs['simlt_res']
        trace1, trace2 = simlt_res[i]['trace'], simlt_res[j]['trace']
        return trace_div(trace1, trace2, self.w)

    def __str__(self):
        return f'{self.magnitude} * GameplaySACN(g={self.g:.3g}, w={self.w}, n={self.n})'

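# The decayed weighting used by SACNovelty, in isolation: the k-th most recent
# segment carries weight 1 - k / (n + 1), so nearer history counts more.
# a_clip (src.utils.mymath, not shown) is assumed to cap the weighted reward
# once the dissimilarity reaches the goal g.
def _novelty_weights(n=defaults['n']):
    return [1 - k / (n + 1) for k in range(1, n + 1)]   # n=5 -> [5/6, 4/6, 3/6, 2/6, 1/6]
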
class Fun(RewardTerm):
    # Penalizes the mean tile-pattern KL divergence over trailing windows when
    # it leaves the [lb, ub] band.
    def __init__(self, magnitude=1., num_windows=3, lb=0.26, ub=0.94, stride=8):
        super().__init__(False)
        self.lb, self.ub = lb, ub
        self.magnitude = magnitude
        self.stride = stride
        self.num_windows = num_windows
        self.n = ceil(num_windows * stride / MarioLevel.seg_width - 1e-8)

    def compute_rewards(self, **kwargs):
        n_segs = len(kwargs['segs'])
        lvl = lvlhcat(kwargs['segs'])
        W = MarioLevel.seg_width
        rewards = []
        for i in range(1, n_segs):
            seg = lvl[:, W * i: W * (i + 1)]
            divs = []
            for k in range(0, self.num_windows + 1):
                s = W * i - k * self.stride
                if s < 0:
                    break
                cmp_seg = lvl[:, s: s + W]
                divs.append(tile_pattern_kl_div(seg, cmp_seg))
            mean_div = np.mean(divs)
            rew = 0
            if mean_div > self.ub:
                rew = -(self.ub - mean_div) ** 2
            if mean_div < self.lb:
                rew = -(self.lb - mean_div) ** 2
            rewards.append(rew * self.magnitude)
        return rewards

    def __str__(self):
        return f'{self.magnitude} * Fun(lb={self.lb:.2f}, ub={self.ub:.2f}, n={self.num_windows}, stride={self.stride})'


class HistoricalDeviation(RewardTerm):
    # Rewards the mean of the m smallest divergences to the n previous
    # segments, i.e. deviation from the closest recent history.
    def __init__(self, magnitude=1., m=3, n=10):
        super().__init__(False)
        self.magnitude = magnitude
        self.m = m
        self.n = n

    def compute_rewards(self, **kwargs):
        segs = kwargs['segs']
        n_segs = len(kwargs['segs'])
        rewards = []
        for i in range(1, n_segs):
            divs = []
            for k in range(1, self.n + 1):
                j = i - k
                if j < 0:
                    break
                divs.append(tile_pattern_kl_div(segs[i], segs[j]))
            divs.sort()
            m = min(i, self.m)
            rew = np.mean(divs[:m])
            rewards.append(rew * self.magnitude)
        return rewards

    def __str__(self):
        return f'{self.magnitude} * HistoricalDeviation(m={self.m}, n={self.n})'


if __name__ == '__main__':
    rfunc = HistoricalDeviation()

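The band penalty used by Fun (and, with different bounds, MeanDivergenceFun), evaluated in isolation on hypothetical mean divergences: zero inside [lb, ub], quadratic outside, scaled by the magnitude:

    lb, ub, magnitude = 0.26, 0.94, 200
    for mean_div in (0.10, 0.50, 1.20):
        rew = 0
        if mean_div > ub:
            rew = -(ub - mean_div) ** 2
        if mean_div < lb:
            rew = -(lb - mean_div) ** 2
        print(mean_div, rew * magnitude)   # approx. -5.12, 0.0, -13.52
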
src/env/rfuncs.py
ADDED
@@ -0,0 +1,9 @@
from src.env.rfunc import *

# Predefined reward-function factories.
default = lambda: RewardFunc(LevelSACN(), GameplaySACN(), Playability())
fp = lambda: RewardFunc(Fun(200), Playability(10))
hp = lambda: RewardFunc(HistoricalDeviation(1), Playability(1))
fhp = lambda: RewardFunc(Fun(30, num_windows=21), HistoricalDeviation(3), Playability(3))
lp = lambda: RewardFunc(LevelSACN(2.), Playability(2.))
gp = lambda: RewardFunc(GameplaySACN(2.), Playability(2.))
lgp = lambda: RewardFunc(LevelSACN(), GameplaySACN(), Playability())

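Each factory above is zero-argument so that every call builds fresh term instances; a quick inspection sketch (assuming src is importable):

    from src.env.rfuncs import fhp

    rfunc = fhp()
    print(rfunc)                 # lists each weighted term via RewardFunc.__str__
    print(rfunc.get_n())         # history length required by the widest term
    print(rfunc.require_simlt)   # True: Playability needs the simulator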