Spaces:
Sleeping
Sleeping
File size: 3,262 Bytes
eaf2e33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
from itertools import chain
from math import ceil
import torch
import random
import numpy as np
from typing import List
from src.gan.gans import nz
from src.gan.gankits import process_onehot, get_decoder
from src.smb.level import MarioLevel, lvlhcat
from src.utils.datastruct import RingQueue
from src.utils.filesys import getpath
class OnlineGenerator:
    """Generate levels segment by segment, letting a policy pick latent vectors.

    A ring buffer holding the last ``policy.n`` latent vectors serves as the
    policy's observation. On every step the policy proposes a new latent vector,
    which is pushed into the buffer and decoded into one level segment.
    """

    def __init__(self, policy, decoder=None, g_device='cuda:0'):
        """
        Args:
            policy: object exposing ``n`` (observation history length),
                ``step(obs)`` and ``reset()``.
            decoder: decoder network called on latent tensors; defaults to
                ``get_decoder()``.
            g_device: torch device the decoder and latent tensors are placed on.
        """
        self.init_vecs = np.load(getpath('smb/init_latvecs.npy'))
        self.policy = policy
        self.decoder = get_decoder() if decoder is None else decoder
        self.decoder.to(g_device)
        self.g_device = g_device
        self.obs_buffer = RingQueue(policy.n)
        # Declare the attribute *before* re_init() fills it in; assigning None
        # afterwards would throw away the segment re_init() just decoded.
        self.init_seg = None
        self.re_init()

    def re_init(self, condition=None):
        """Reset the observation buffer and decode the initial segment.

        The buffer is first filled with ``policy.n`` all-zero vectors of length
        ``nz``. Then either every vector in ``condition`` is pushed in order, or,
        when ``condition`` is None, one vector chosen at random from
        ``self.init_vecs``. The buffer's rear element (presumably the most
        recently pushed vector) is decoded into ``self.init_seg``.

        Args:
            condition: optional sequence of latent vectors used to seed the
                buffer instead of a random starting vector.
        """
        for _ in range(self.policy.n):
            self.obs_buffer.push(np.zeros([nz]))
        if condition is not None:
            # Push each conditioning vector individually.
            for latvec in condition:
                self.obs_buffer.push(latvec)
        else:
            self.obs_buffer.push(random.choice(self.init_vecs))
        init_z = torch.tensor(self.obs_buffer.rear(), device=self.g_device).view(1, -1, 1, 1)
        self.init_seg = process_onehot(self.decoder(init_z))

    def step(self):
        """Query the policy once and decode its latent vector into a segment.

        The observation is the buffered vectors concatenated along the last
        axis. The policy's output is pushed into the buffer before decoding.
        """
        obs = np.concatenate(self.obs_buffer.to_list(), axis=-1)
        latvec = self.policy.step(obs)
        self.obs_buffer.push(latvec)
        z = torch.tensor(latvec, device=self.g_device).view(-1, nz, 1, 1)
        return process_onehot(self.decoder(z))

    def forward(self, l) -> List[MarioLevel]:
        """Produce one level as ``l + 1`` segments: the initial one plus ``l`` steps.

        Both the buffer (via ``re_init``) and the policy are reset first.
        """
        self.re_init()
        self.policy.reset()
        return [self.init_seg, *(self.step() for _ in range(l))]

    def generate(self, n, l):
        """Generate ``n`` levels, each made by joining ``l + 1`` segments with ``lvlhcat``."""
        return [lvlhcat(self.forward(l)) for _ in range(n)]
class VecOnlineGenerator(OnlineGenerator):
    """Batched ``OnlineGenerator`` that builds ``vec_num`` levels in parallel.

    Each buffer entry is a ``(vec_num, nz)`` array, so one decoder call yields a
    batch of segments, one per level being generated.
    """

    def __init__(self, policy, decoder=None, g_device='cuda:0', vec_num=50):
        # vec_num must be set before the base constructor runs, because the
        # base constructor calls re_init(), which reads it.
        self.vec_num = vec_num
        super().__init__(policy, decoder, g_device)

    def re_init(self, condition=None):
        """Reset the buffer with zero batches and decode a batch of initial segments.

        Args:
            condition: optional batch of latent vectors pushed as a single
                buffer entry (presumably shaped ``(vec_num, nz)`` like the zero
                padding). When None, ``vec_num`` distinct rows of
                ``self.init_vecs`` are sampled instead; ``random.sample`` raises
                ValueError if ``vec_num`` exceeds the number of rows.
        """
        for _ in range(self.policy.n):
            self.obs_buffer.push(np.zeros([self.vec_num, nz]))
        if condition is not None:
            self.obs_buffer.push(condition)
        else:
            rows = random.sample(range(len(self.init_vecs)), self.vec_num)
            self.obs_buffer.push(self.init_vecs[rows])
        init_z = torch.tensor(self.obs_buffer.rear(), device=self.g_device).view(-1, nz, 1, 1)
        self.init_seg = process_onehot(self.decoder(init_z))

    def forward(self, l, rand_init=True):
        """Generate one batch of levels, each a list of ``l + 1`` segments.

        Args:
            l: number of policy steps after the initial segment.
            rand_init: if True, call ``re_init()`` first. If False, keep the
                current initial segments and buffer contents.

        Returns:
            One segment list per element of ``self.init_seg``.
        """
        # Also re-initialise when no initial segment exists yet. Otherwise the
        # first call with rand_init=False would iterate over None.
        if rand_init or self.init_seg is None:
            self.re_init()
        self.policy.reset()
        lvls = [[seg] for seg in self.init_seg]
        for _ in range(l):
            for lvl, seg in zip(lvls, self.step()):
                lvl.append(seg)
        return lvls

    def generate(self, n, l, rand_init=True):
        """Generate exactly ``n`` levels, running as many batches as needed.

        Every batch yields up to ``vec_num`` levels; surplus levels from the last
        batch are discarded.
        """
        n_batches = ceil(n / self.vec_num)
        batches = [
            [lvlhcat(segs) for segs in self.forward(l, rand_init)]
            for _ in range(n_batches)
        ]
        return list(chain.from_iterable(batches))[:n]
if __name__ == '__main__':
    # Scratch check of the row-sampling pattern used by VecOnlineGenerator.re_init:
    # index a 2-D array with a list of distinct random row indices.
    demo = np.random.rand(10, 3)
    print(demo)
    picked_rows = random.sample(range(len(demo)), 2)
    print(demo[picked_rows])
|