Spaces:
Sleeping
Sleeping
# --- Standard library -------------------------------------------------------
import random
import os
import sys
from os import path

# --- Third-party UI toolkit -------------------------------------------------
import gradio as gr

# Put the directory containing this file on the module search path so the
# project-local ``src`` package imported below can be resolved.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# --- Project-local modules (rely on the sys.path tweak above) ---------------
from src.olgen.ol_generator import VecOnlineGenerator
# from src.olgen.olg_game import MarioOnlineGenGame  # interactive play; disabled
from src.olgen.olg_policy import RLGenPolicy
from src.smb.level import save_batch
from src.utils.filesys import getpath
from src.utils.img import make_img_sheet

import torch

# Compute device used for loading the policy and running generation:
# the first CUDA GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
def generate_and_play(n_levels: int = 8, n_segments: int = 10,
                      model_path: str = 'models/example_policy') -> list:
    """Generate a batch of levels with a trained policy and render them.

    This is the click handler of the "Generate levels" button, which passes no
    inputs. Every parameter therefore has a default that reproduces the
    previously hard-coded values. Despite the name, nothing is played or
    written to disk here; the function only returns images.

    Args:
        n_levels: Number of levels in the batch. Passed as the first argument
            to ``VecOnlineGenerator.generate``.
        n_segments: Second argument to ``VecOnlineGenerator.generate``.
            Presumably the number of segments per level (TODO confirm against
            ``VecOnlineGenerator``).
        model_path: Location of the trained policy, handed to
            ``RLGenPolicy.from_path``.

    Returns:
        A list holding ``lvl.to_img()`` for each generated level, in
        generation order (suitable as a ``gr.Gallery`` value).

    Raises:
        ValueError: If ``n_levels`` or ``n_segments`` is less than 1.
    """
    # Fail fast on meaningless sizes before paying for a model load.
    if n_levels < 1 or n_segments < 1:
        raise ValueError(
            f'n_levels and n_segments must be positive, '
            f'got {n_levels} and {n_segments}'
        )
    # Named ``model_path`` rather than ``path`` so it does not shadow the
    # module-level ``from os import path`` alias.
    policy = RLGenPolicy.from_path(model_path, device)
    generator = VecOnlineGenerator(policy, g_device=device)
    levels = generator.generate(n_levels, n_segments)
    return [level.to_img() for level in levels]
# Single-page demo UI: one button that runs ``generate_and_play`` and shows
# the returned level images in a gallery.
with gr.Blocks(title="NCERL Demo") as demo:
    gallery = gr.Gallery(
        label="Generated images",
        show_label=False,
        elem_id="gallery",
        columns=[3],
        rows=[1],
        object_fit="contain",
        height="auto",
    )
    btn = gr.Button("Generate levels", scale=0)
    # The handler takes no UI inputs; its return value fills the gallery.
    btn.click(fn=generate_and_play, inputs=None, outputs=gallery)


# Start the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()