Spaces:
Sleeping
Sleeping
import torch as t | |
import numpy as np | |
from torch.utils.data import Dataset, DataLoader | |
from src.gan.adversarial_train import get_gan_train_data | |
from src.utils.filesys import getpath | |
# filepath = Path(__file__).parent.resolve() | |
# DATA_PATH = os.path.join(filepath, "levels", "ground", "unique_onehot.npz") | |
DATA_PATH = getpath('smb/levels') | |
class MarioDataset(Dataset): | |
def __init__(self, data): | |
self.data = data | |
def __len__(self): | |
return len(self.data) | |
def __getitem__(self, idx): | |
sample = self.data[idx] | |
x = t.tensor(sample, dtype=t.float32) | |
return x | |
# def load_data(file_path): | |
# # data = np.load(file_path) | |
# # levels = data['levels'] | |
# levels = traverse_level_files(DATA_PATH) | |
# onehots = [] | |
# for lvl in levels: | |
# num_lvl = lvl.to_num_arr() | |
# _, length = num_lvl.shape | |
# for s in range(length - W): | |
# seg = num_lvl[:, s: s+W] | |
# onehot = np.zeros([MarioLevel.n_types, H, W]) | |
# xs = [seg[i, j] for i, j in product(range(H), range(W))] | |
# ys = [k // W for k in range(H * W)] | |
# zs = [k % W for k in range(H * W)] | |
# onehot[xs, ys, zs] = 1 | |
# data.append(onehot) | |
# | |
# return [lvl.to_onehot() for lvl in levels] | |
def create_dataloader(batch_size=32, shuffle=True, num_workers=0): | |
# data = load_data(file_path) | |
data = get_gan_train_data() | |
dataset = MarioDataset(data) | |
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) | |
return dataloader | |
if __name__ == '__main__': | |
# mario_dataloader = create_dataloader(DATA_PATH, batch_size=64, shuffle=True) | |
# # collect the first batch from mario_dataloader | |
# first_batch = next((iter(mario_dataloader))) | |
# print(first_batch.shape) | |
# # plot all the levels in the first batch | |
# fig, axes = plt.subplots(8, 8, figsize=(10, 10)) | |
# for i, ax in enumerate(axes.flatten()): | |
# level = np.argmax(first_batch[i], axis=0).numpy() | |
# image = get_img_from_level(level) | |
# ax.imshow(255 * np.ones_like(image)) # White background | |
# ax.imshow(image) | |
# ax.axis("off") | |
# plt.show() | |
data_ = create_dataloader(batch_size=64, shuffle=True).dataset.data | |
data_ = np.argmax(data_, axis=1) | |
for item in data_: | |
print(item.shape) | |
# count the occurrence of each number in data | |
# for i in range(MarioLevel.n_types): | |
# print(f"{i}: {np.count_nonzero(data_ == i)}") | |
print([np.count_nonzero(data_ == i) for i in range(21)]) | |
pass |