Spaces:
Sleeping
Sleeping
from itertools import chain | |
from math import ceil | |
import torch | |
import random | |
import numpy as np | |
from typing import List | |
from src.gan.gans import nz | |
from src.gan.gankits import process_onehot, get_decoder | |
from src.smb.level import MarioLevel, lvlhcat | |
from src.utils.datastruct import RingQueue | |
from src.utils.filesys import getpath | |
class OnlineGenerator:
    """Generate Mario levels segment by segment from a latent-space policy.

    The generator keeps a buffer of recent latent vectors. At each step the
    policy receives those vectors concatenated along the last axis and returns
    the next latent vector, which is decoded into a level segment.
    """

    def __init__(self, policy, decoder=None, g_device='cuda:0'):
        """
        Args:
            policy: object exposing ``n`` (used as the observation buffer
                size), ``step(obs)`` returning the next latent vector, and
                ``reset()``.
            decoder: latent-to-segment decoder; ``get_decoder()`` is used
                when None.
            g_device: torch device the decoder and latent tensors live on.
        """
        self.init_vecs = np.load(getpath('smb/init_latvecs.npy'))
        self.policy = policy
        self.decoder = get_decoder() if decoder is None else decoder
        self.decoder.to(g_device)
        self.g_device = g_device
        self.obs_buffer = RingQueue(policy.n)
        # Declare the attribute *before* re_init() fills it; assigning None
        # afterwards would discard the initial segment re_init() just decoded.
        self.init_seg = None
        self.re_init()

    def re_init(self, condition=None):
        """Reset the observation buffer and decode the initial segment.

        The buffer is first filled with ``policy.n`` zero vectors of length
        ``nz``. Then either every latent vector in ``condition`` is pushed in
        order, or (when ``condition`` is None) one vector chosen at random from
        ``init_vecs``. ``obs_buffer.rear()`` (presumably the most recently
        pushed vector) is decoded into ``self.init_seg``.

        Args:
            condition: optional iterable of latent vectors used to seed the
                buffer.
        """
        for _ in range(self.policy.n):
            self.obs_buffer.push(np.zeros([nz]))
        if condition is not None:
            for latvec in condition:
                # Push each vector individually; pushing the whole
                # ``condition`` per element was a bug.
                self.obs_buffer.push(latvec)
        else:
            self.obs_buffer.push(random.choice(self.init_vecs))
        z0 = torch.tensor(self.obs_buffer.rear(), device=self.g_device).view(1, -1, 1, 1)
        self.init_seg = process_onehot(self.decoder(z0))

    def step(self):
        """Query the policy once, buffer its latent vector and decode it.

        Returns:
            Whatever ``process_onehot`` produces for the decoded latent vector.
        """
        obs = np.concatenate(self.obs_buffer.to_list(), axis=-1)
        latvec = self.policy.step(obs)
        self.obs_buffer.push(latvec)
        z = torch.tensor(latvec, device=self.g_device).view(-1, nz, 1, 1)
        return process_onehot(self.decoder(z))

    def forward(self, l) -> List[MarioLevel]:
        """Return ``l + 1`` segments: a fresh initial one followed by ``l`` policy steps."""
        self.re_init()
        self.policy.reset()
        return [self.init_seg, *(self.step() for _ in range(l))]

    def generate(self, n, l):
        """Generate ``n`` levels, each the ``lvlhcat`` of ``forward(l)``'s segments."""
        return [lvlhcat(self.forward(l)) for _ in range(n)]
class VecOnlineGenerator(OnlineGenerator):
    """Batched OnlineGenerator that advances ``vec_num`` levels in parallel.

    Each buffered observation is a ``(vec_num, nz)`` array, one row per level.
    """

    def __init__(self, policy, decoder=None, g_device='cuda:0', vec_num=50):
        # vec_num must be set before the base constructor calls self.re_init().
        self.vec_num = vec_num
        super().__init__(policy, decoder, g_device)

    def re_init(self, condition=None):
        """Reset the buffer with zeros and seed it with one latent row per level.

        Args:
            condition: optional array pushed as-is as the seed observation
                (presumably shaped ``(vec_num, nz)`` — TODO confirm). When
                None, ``vec_num`` distinct rows of ``init_vecs`` are sampled
                without replacement.
        """
        for _ in range(self.policy.n):
            self.obs_buffer.push(np.zeros([self.vec_num, nz]))
        if condition is not None:
            self.obs_buffer.push(condition)
        else:
            # random.sample raises ValueError if vec_num > len(init_vecs).
            rows = random.sample(range(len(self.init_vecs)), self.vec_num)
            self.obs_buffer.push(self.init_vecs[rows])
        z0 = torch.tensor(self.obs_buffer.rear(), device=self.g_device).view(-1, nz, 1, 1)
        self.init_seg = process_onehot(self.decoder(z0))

    def forward(self, l, rand_init=True):
        """Generate one list of ``l + 1`` segments per entry of ``init_seg``.

        Args:
            l: number of policy steps after the initial segment.
            rand_init: when True, call ``re_init()`` first; when False, reuse
                the current ``init_seg`` and observation buffer as they are
                (e.g. after an explicit ``re_init(condition)``).
        """
        # init_seg can still be None right after construction; reusing it
        # would crash, so fall back to a fresh random initialisation.
        if rand_init or self.init_seg is None:
            self.re_init()
        self.policy.reset()
        lvls = [[seg] for seg in self.init_seg]
        for _ in range(l):
            for lvl, seg in zip(lvls, self.step()):
                lvl.append(seg)
        return lvls

    def generate(self, n, l, rand_init=True):
        """Generate ``n`` levels in batches of ``vec_num``, each joined with ``lvlhcat``."""
        n_batches = ceil(n / self.vec_num)
        batches = [[lvlhcat(segs) for segs in self.forward(l, rand_init)] for _ in range(n_batches)]
        return list(chain.from_iterable(batches))[:n]
if __name__ == '__main__':
    # Exercises the same random.sample row-indexing pattern used by
    # VecOnlineGenerator.re_init on a small random matrix.
    demo = np.random.rand(10, 3)
    print(demo)
    picked_rows = random.sample(range(len(demo)), 2)
    print(demo[picked_rows])