Spaces:
Sleeping
Sleeping
import torch | |
from src.smb.level import MarioLevel | |
from src.gan.gans import nz | |
from src.utils.filesys import getpath | |
# Sample latent noise vectors for the generator.
def sample_latvec(n=1, device='cpu', distribuion='uniform'):
    """Draw ``n`` latent vectors of shape ``(n, nz, 1, 1)``.

    Args:
        n: number of vectors to draw (first dimension of the result).
        device: torch device the tensor is created on.
        distribuion: ``'uniform'`` for values in ``[-1, 1)`` or ``'normal'``
            for a standard normal. (The parameter name's spelling is kept
            as-is because callers may pass it by keyword.)

    Returns:
        A ``torch.Tensor`` of shape ``(n, nz, 1, 1)``.

    Raises:
        TypeError: if ``distribuion`` is not one of the supported names.
            (ValueError would be more conventional; TypeError is kept so
            existing ``except TypeError`` handlers keep working.)
    """
    if distribuion == 'uniform':
        # torch.rand is in [0, 1); rescale to [-1, 1).
        return torch.rand(n, nz, 1, 1, device=device) * 2 - 1
    if distribuion == 'normal':
        return torch.randn(n, nz, 1, 1, device=device)
    raise TypeError(f'unknown noise distribution: {distribuion}')
# Convert one-hot level tensors into MarioLevel objects.
def process_onehot(raw_tensor_onehot):
    """Decode one-hot tensors into ``MarioLevel`` objects.

    Each element of ``raw_tensor_onehot`` is cropped with
    ``[:, :MarioLevel.height, :MarioLevel.seg_width]``, detached, moved to
    the CPU, converted to a NumPy array, and passed to
    ``MarioLevel.from_one_hot_arr``.

    Args:
        raw_tensor_onehot: an iterable of torch tensors (for example a batched
            tensor iterated along its first axis). Each element must be at
            least 3-D; presumably (channels, height, width) — TODO confirm
            against the decoder output.

    Returns:
        A single ``MarioLevel`` when exactly one element was given; otherwise
        a list of ``MarioLevel`` objects (an empty list for empty input).
    """
    H, W = MarioLevel.height, MarioLevel.seg_width
    levels = [
        MarioLevel.from_one_hot_arr(single[:, :H, :W].detach().cpu().numpy())
        for single in raw_tensor_onehot
    ]
    # Unwrap a lone level for caller convenience. Checking len == 1 (instead
    # of indexing unconditionally) keeps empty input from raising IndexError.
    if len(levels) == 1:
        return levels[0]
    return levels
def get_decoder(path='models/decoder.pth', device='cpu'):
    """Load a pickled decoder module and freeze it for inference.

    Args:
        path: checkpoint path, resolved through ``getpath``.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        The loaded object with gradients disabled (``requires_grad_(False)``)
        and switched to eval mode (``eval()``) — i.e. it must be an
        ``nn.Module``-like object.
    """
    import inspect

    load_kwargs = {'map_location': device}
    # The checkpoint stores a whole pickled module, not a state_dict. Since
    # PyTorch 2.6, torch.load defaults to weights_only=True, which rejects
    # such pickles, so request full unpickling explicitly. torch < 1.13 has
    # no `weights_only` argument, so only pass it when it is supported.
    # SECURITY: full unpickling can execute arbitrary code — only load
    # trusted, locally produced checkpoints here.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    decoder = torch.load(getpath(path), **load_kwargs)
    decoder.requires_grad_(False)
    decoder.eval()
    return decoder