# src/env/environments.py
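# Online level-generation (OLGen) environments for Super Mario Bros.
# An "action" is a latent vector fed to a GAN decoder that produces one level
# segment; an episode is a sequence of segments concatenated into a full level,
# which is scored by a reward function (optionally via the Mario simulator).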
import random
import time
import gym
import torch
import numpy as np
from typing import Tuple, List, Dict, Callable, Optional
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.base_vec_env import VecEnvStepReturn, VecEnvObs
from stable_baselines3.common.vec_env.subproc_vec_env import _flatten_obs
from src.env.logger import InfoCollector
from src.env.rfunc import RewardFunc
from src.smb.asyncsimlt import AsycSimltPool
from src.smb.level import MarioLevel, lvlhcat
from src.gan.gans import SAGenerator
from src.gan.gankits import *
from src.smb.proxy import MarioProxy, MarioJavaAgents
from src.utils.datastruct import RingQueue
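# Concatenate the latest `histlen` latent vectors into one flat observation,
# left-padding with zero vectors when fewer than `histlen` are available.
# `nz` (the latent dimension) is presumably exported by the wildcard import
# from src.gan.gankits.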
def get_padded_obs(vecs, histlen, add_batch_dim=False):
if len(vecs) < histlen:
lack = histlen - len(vecs)
pad = [np.zeros([nz], np.float32) for _ in range(lack)]
obs = np.concatenate(pad + vecs)
else:
obs = np.concatenate(vecs)
if add_batch_dim:
obs = np.reshape(obs, [1, -1])
return obs
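# Single-process gym environment: each action is one latent vector, the decoder
# turns the accumulated vectors into level segments, and rewards are computed
# only when the episode finishes, using the in-process MarioProxy simulator.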
class SingleProcessOLGenEnv(gym.Env):
def __init__(self, rfunc, decoder: SAGenerator, eplen: int=50, device='cuda:0'):
self.rfunc = rfunc
self.initvec_set = np.load(getpath('smb/init_latvecs.npy'))
self.decoder = decoder
self.decoder.to(device)
self.hist_len = self.rfunc.get_n()
self.eplen = eplen
self.device = device
self.action_space = gym.spaces.Box(-1, 1, (nz,))
self.observation_space = gym.spaces.Box(-1, 1, (self.hist_len * nz,))
self.lat_vecs = []
self.simulator = MarioProxy()
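    # One step appends the latent-vector action; the episode terminates once
    # eplen + 1 vectors (initial vector + eplen actions) have been collected,
    # at which point the whole level is decoded, simulated and rewarded.
    # The per-step reward returned here is always 0; the actual rewards are
    # packed into info['rewss'] and info['transitions'] on the final step.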
def step(self, action):
self.lat_vecs.append(action)
done = len(self.lat_vecs) == (self.eplen + 1)
info = {}
if done:
            rewss = self.__evaluate()
info['rewss'] = rewss
rewsums = [sum(items) for items in zip(*rewss.values())]
info['transitions'] = self.__process_traj(rewsums[-self.eplen:])
self.reset()
return self.getobs(), 0, done, info
    def __evaluate(self):
z = torch.tensor(np.stack(self.lat_vecs).reshape([-1, nz, 1, 1]), device=self.device, dtype=torch.float)
segs = process_onehot(self.decoder(z))
lvl = lvlhcat(segs)
simlt_res = MarioProxy.get_seg_infos(self.simulator.simulate_complete(lvl))
rewardss = self.rfunc.get_rewards(segs=segs, simlt_res=simlt_res)
return rewardss
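    # Build (obs, action, reward, next_obs) transitions from the stored latent
    # trajectory; obs[i] is the padded window ending at vector i, and the action
    # taken from obs[i] is lat_vecs[i + 1] (lat_vecs[0] is the initial vector).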
def __process_traj(self, rewards):
obs = []
for i in range(1, len(self.lat_vecs) + 1):
ob = get_padded_obs(self.lat_vecs[max(0, i - self.hist_len): i], self.hist_len)
obs.append(ob)
traj = [(obs[i], self.lat_vecs[i+1], rewards[i], obs[i+1]) for i in range(len(self.lat_vecs) - 1)]
return traj
def reset(self):
self.lat_vecs.clear()
z0 = self.initvec_set[random.randrange(0, len(self.initvec_set))]
self.lat_vecs.append(z0)
return self.getobs()
def getobs(self):
s = max(0, len(self.lat_vecs) - self.hist_len)
return get_padded_obs(self.lat_vecs[s:], self.hist_len, True)
def render(self, mode="human"):
pass
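    # Roll out a trained agent in parallel "virtual" episodes without reward
    # computation: each RingQueue keeps the observation window of one level,
    # and the collected latent trajectories are decoded into full levels.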
def generate_levels(self, agent, n=1, max_parallel=None):
if max_parallel is None:
max_parallel = min(n, 512)
levels = []
latvecs = []
obs_queues = [RingQueue(self.hist_len) for _ in range(max_parallel)]
while len(levels) < n:
veclists = [[] for _ in range(min(max_parallel, n - len(levels)))]
for queue, veclist in zip(obs_queues, veclists):
queue.clear()
init_latvec = self.initvec_set[random.randrange(0, len(self.initvec_set))]
queue.push(init_latvec)
veclist.append(init_latvec)
for _ in range(self.eplen):
obs = np.stack([get_padded_obs(queue.to_list(), self.hist_len) for queue in obs_queues])
actions = agent.make_decision(obs)
for queue, veclist, action in zip(obs_queues, veclists, actions):
queue.push(action)
veclist.append(action)
for veclist in veclists:
latvecs.append(np.stack(veclist))
z = torch.tensor(latvecs[-1], device=self.device).view(-1, nz, 1, 1)
lvl = lvlhcat(process_onehot(self.decoder(z)))
levels.append(lvl)
return levels, latvecs
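# Asynchronous variant: instead of simulating in-process, finished latent
# trajectories are decoded to levels and submitted to an AsycSimltPool worker
# pool for evaluation; transitions are assembled later in rollout() once the
# rewards come back.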
class AsyncOlGenEnv:
def __init__(self, histlen, decoder: SAGenerator, eval_pool: AsycSimltPool, eplen: int=50, device='cuda:0'):
self.initvec_set = np.load(getpath('smb/init_latvecs.npy'))
self.decoder = decoder
self.decoder.to(device)
self.device = device
        # The Mario simulator runs inside eval_pool
self.eval_pool = eval_pool
self.eplen = eplen
self.tid = 0
self.histlen = histlen
self.cur_vectraj = []
self.buffer = {}
def reset(self):
if len(self.cur_vectraj) > 0:
self.buffer[self.tid] = self.cur_vectraj
self.cur_vectraj = []
self.tid += 1
z0 = self.initvec_set[random.randrange(0, len(self.initvec_set))]
self.cur_vectraj.append(z0)
return self.getobs()
def step(self, action):
self.cur_vectraj.append(action)
done = len(self.cur_vectraj) == (self.eplen + 1)
if done:
self.__submit_eval_task()
self.reset()
return self.getobs(), done
def getobs(self):
s = max(0, len(self.cur_vectraj) - self.histlen)
return get_padded_obs(self.cur_vectraj[s:], self.histlen, True)
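    # Decode the finished latent trajectory (clamped to [-1, 1]), concatenate
    # the segments into one level, and submit the level string to the
    # evaluation pool keyed by this trajectory's id.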
def __submit_eval_task(self):
z = torch.tensor(np.stack(self.cur_vectraj).reshape([-1, nz, 1, 1]), device=self.device)
segs = process_onehot(self.decoder(torch.clamp(z, -1, 1)))
lvl = lvlhcat(segs)
args = (self.tid, str(lvl))
self.eval_pool.put('evaluate', args)
def refresh(self):
if self.eval_pool is not None:
self.eval_pool.refresh()
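    # Collect whatever evaluation results are ready (or all of them when
    # closing/waiting), look up the matching latent trajectory in the buffer,
    # and convert each into a list of (obs, action, reward, next_obs) tuples.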
def rollout(self, close=False, wait=False) -> Tuple[List[Tuple], List[Dict[str, List]]]:
transitions, rewss = [], []
if close:
eval_res = self.eval_pool.close()
else:
eval_res = self.eval_pool.get(wait)
for tid, rewards in eval_res:
rewss.append(rewards)
rewsums = [sum(items) for items in zip(*rewards.values())]
vectraj = self.buffer.pop(tid)
transitions += self.__process_traj(vectraj, rewsums[-self.eplen:])
return transitions, rewss
def __process_traj(self, vectraj, rewards):
obs = []
for i in range(1, len(vectraj) + 1):
ob = get_padded_obs(vectraj[max(0, i - self.histlen): i], self.histlen)
obs.append(ob)
traj = [(obs[i], vectraj[i+1], rewards[i], obs[i+1]) for i in range(len(vectraj) - 1)]
return traj
def close(self):
res = self.rollout(True)
self.eval_pool = None
return res
def generate_levels(self, agent, n=1, max_parallel=None):
if max_parallel is None:
max_parallel = min(n, 512)
levels = []
latvecs = []
obs_queues = [RingQueue(self.histlen) for _ in range(max_parallel)]
while len(levels) < n:
veclists = [[] for _ in range(min(max_parallel, n - len(levels)))]
for queue, veclist in zip(obs_queues, veclists):
queue.clear()
init_latvec = self.initvec_set[random.randrange(0, len(self.initvec_set))]
queue.push(init_latvec)
veclist.append(init_latvec)
for _ in range(self.eplen):
obs = np.stack([get_padded_obs(queue.to_list(), self.histlen) for queue in obs_queues])
actions = agent.make_decision(obs)
for queue, veclist, action in zip(obs_queues, veclists, actions):
queue.push(action)
veclist.append(action)
for veclist in veclists:
latvecs.append(np.stack(veclist))
z = torch.tensor(latvecs[-1], device=self.device).view(-1, nz, 1, 1)
lvl = lvlhcat(process_onehot(self.decoder(z)))
levels.append(lvl)
return levels, latvecs
######### Adopt from https://github.com/SUSTechGameAI/MFEDRL #########
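# Worker environment used inside the subprocess vectorised setup below. The
# latent vectors are decoded in the main process (see VecOLGenEnv.step_async),
# so each step here receives the action together with its already-decoded
# segment string.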
class SyncOLGenWorkerEnv(gym.Env):
def __init__(self, rfunc=None, hist_len=5, eplen=25, return_lvl=False, init_one=False, play_style='Runner'):
self.rfunc = RewardFunc() if rfunc is None else rfunc
self.mario_proxy = MarioProxy() if self.rfunc.require_simlt else None
self.action_space = gym.spaces.Box(-1, 1, (nz,))
self.hist_len = hist_len
self.observation_space = gym.spaces.Box(-1, 1, (hist_len * nz,))
self.segs = []
self.latvec_archive = RingQueue(hist_len)
self.eplen = eplen
self.counter = 0
# self.repairer = DivideConquerRepairer()
self.init_one = init_one
self.backup_latvecs = None
self.backup_strsegs = None
self.return_lvl = return_lvl
        self.jagent = MarioJavaAgents[play_style]
self.simlt_k = 80 if play_style == 'Runner' else 320
    def receive(self, **kwargs):
        for key, val in kwargs.items():
            setattr(self, key, val)
def step(self, data):
action, strseg = data
seg = MarioLevel(strseg)
self.latvec_archive.push(action)
self.counter += 1
self.segs.append(seg)
done = self.counter >= self.eplen
if done:
full_level = lvlhcat(self.segs)
w = MarioLevel.seg_width
segs = [full_level[:, s: s + w] for s in range(0, full_level.w, w)]
if self.mario_proxy:
raw_simlt_res = self.mario_proxy.simulate_complete(lvlhcat(segs), self.jagent, self.simlt_k)
simlt_res = MarioProxy.get_seg_infos(raw_simlt_res)
else:
simlt_res = None
rewards = self.rfunc.get_rewards(segs=segs, simlt_res=simlt_res)
info = {}
total_score = 0
if self.return_lvl:
info['LevelStr'] = str(full_level)
for key in rewards:
info[f'{key}_reward_list'] = rewards[key][-self.eplen:]
info[f'{key}'] = sum(rewards[key][-self.eplen:])
total_score += info[f'{key}']
info['TotalScore'] = total_score
info['EpLength'] = self.counter
else:
info = {}
return self.__get_obs(), 0, done, info
def reset(self):
self.segs.clear()
self.latvec_archive.clear()
for latvec, strseg in zip(self.backup_latvecs, self.backup_strsegs):
self.latvec_archive.push(latvec)
self.segs.append(MarioLevel(strseg))
self.backup_latvecs, self.backup_strsegs = None, None
self.counter = 0
return self.__get_obs()
def __get_obs(self):
lack = self.hist_len - len(self.latvec_archive)
pad = [np.zeros([nz], np.float32) for _ in range(lack)]
return np.concatenate([*pad, *self.latvec_archive.to_list()])
def render(self, mode='human'):
pass
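# SubprocVecEnv subclass that owns the GAN decoder: actions from the learner
# are decoded into level segments on the main process (on `device`) and shipped
# to the worker environments as strings, together with the raw latent vectors.
# It also pushes fresh initial latent vectors/segments to workers whenever an
# episode ends, so the workers can auto-reset.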
class VecOLGenEnv(SubprocVecEnv):
def __init__(
self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[str] = None, hist_len=5, eplen=50,
init_one=True, log_path=None, log_itv=-1, log_targets=None, device='cuda:0'
):
super(VecOLGenEnv, self).__init__(env_fns, start_method)
self.decoder = get_decoder(device=device)
if log_path:
self.logger = InfoCollector(log_path, log_itv, log_targets)
else:
self.logger = None
self.hist_len = hist_len
self.total_steps = 0
self.start_time = time.time()
self.eplen = eplen
self.device = device
self.init_one = init_one
self.latvec_set = np.load(getpath('smb/init_latvecs.npy'))
def step_async(self, actions: np.ndarray) -> None:
with torch.no_grad():
z = torch.tensor(actions.astype(np.float32), device=self.device).view(-1, nz, 1, 1)
segs = process_onehot(self.decoder(z))
for remote, action, seg in zip(self.remotes, actions, segs):
remote.send(("step", (action, str(seg))))
self.waiting = True
def step_wait(self) -> VecEnvStepReturn:
self.total_steps += self.num_envs
results = [remote.recv() for remote in self.remotes]
self.waiting = False
obs, rews, dones, infos = zip(*results)
envs_to_send = [i for i in range(self.num_envs) if dones[i]]
self.send_reset_data(envs_to_send)
if self.logger is not None:
for i in range(self.num_envs):
if infos[i]:
infos[i]['TotalSteps'] = self.total_steps
infos[i]['TimePassed'] = time.time() - self.start_time
self.logger.on_step(dones, infos)
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos
    def reset(self) -> VecEnvObs:
        # Hand each worker its initial latent vectors/segments before resetting,
        # then preload a second batch so workers can auto-reset after `done`.
        self.send_reset_data()
        for remote in self.remotes:
            remote.send(("reset", None))
        obs = [remote.recv() for remote in self.remotes]
        self.send_reset_data()
        return _flatten_obs(obs, self.observation_space)
def send_reset_data(self, env_ids=None):
if env_ids is None:
env_ids = [*range(self.num_envs)]
target_remotes = self._get_target_remotes(env_ids)
n_inits = 1 if self.init_one else self.hist_len
latvecs = [self.latvec_set[random.sample(range(len(self.latvec_set)), n_inits)] for _ in range(len(env_ids))]
with torch.no_grad():
segss = [[] for _ in range(len(env_ids))]
for i in range(len(env_ids)):
z = torch.tensor(latvecs[i]).view(-1, nz, 1, 1).to(self.device)
segss[i] = [process_onehot(self.decoder(z))] if self.init_one else process_onehot(self.decoder(z))
for remote, latvec, segs in zip(target_remotes, latvecs, segss):
kwargs = {'backup_latvecs': latvec, 'backup_strsegs': [str(seg) for seg in segs]}
remote.send(("env_method", ('receive', [], kwargs)))
for remote in target_remotes:
remote.recv()
def close(self) -> None:
super().close()
if self.logger is not None:
self.logger.close()
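# Factory wrapping everything above into a Stable-Baselines3 vectorised
# environment: `num_envs` SyncOLGenWorkerEnv workers behind a VecOLGenEnv.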
def make_vec_offrew_env(
num_envs, rfunc=None, log_path=None, eplen=25, log_itv=-1, hist_len=5, init_one=True,
play_style='Runner', device='cuda:0', log_targets=None, return_lvl=False
):
return make_vec_env(
SyncOLGenWorkerEnv, n_envs=num_envs, vec_env_cls=VecOLGenEnv,
vec_env_kwargs={
'log_path': log_path,
'log_itv': log_itv,
'log_targets': log_targets,
'device': device,
'eplen': eplen,
'hist_len': hist_len,
'init_one': init_one
},
env_kwargs={
'rfunc': rfunc,
'eplen': eplen,
'return_lvl': return_lvl,
'play_style': play_style,
'hist_len': hist_len,
'init_one': init_one
}
)
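# ----------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# It assumes the GAN decoder weights, 'smb/init_latvecs.npy' and the Mario
# simulator dependencies are all available, and that running the decoder on
# 'cpu' is supported; the reward function falls back to the default RewardFunc.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    venv = make_vec_offrew_env(num_envs=2, eplen=5, device='cpu')
    obs = venv.reset()
    for _ in range(5):
        # One random latent-vector action per worker environment.
        actions = np.stack([venv.action_space.sample() for _ in range(venv.num_envs)])
        obs, rews, dones, infos = venv.step(actions)
    venv.close()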