mariogpt / app.py
multimodalart's picture
Update app.py
d59d1e6
raw
history blame
1.21 kB
import gradio as gr
import torch
from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png
# Instantiate the MarioGPT language model once at import time; `update` below
# reuses this module-level instance for every request.
mario_lm = MarioLM()
# Prefer the GPU, but fall back to CPU so the app still starts on machines
# without CUDA (a hard-coded torch.device('cuda') makes the .to() call below
# raise there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mario_lm = mario_lm.to(device)
# Directory of tile images passed to convert_level_to_png when rendering levels.
TILE_DIR = "data/tiles"
def update(prompt, progress=gr.Progress(track_tqdm=True)):
    """Sample a Mario level for ``prompt`` and return it rendered as an image.

    Args:
        prompt: Free-text level description, e.g. "many pipes, some enemies".
        progress: Gradio progress tracker. ``track_tqdm=True`` lets Gradio
            follow the tqdm bar that ``mario_lm.sample(..., use_tqdm=True)``
            emits. It is never read directly here.

    Returns:
        The first element of ``convert_level_to_png``'s result, which is
        shown in the ``gr.Image`` output. Presumably this is the PIL image;
        confirm against mario_gpt.utils.
    """
    # sample() takes a list of prompts; this demo always generates one level.
    level_tokens = mario_lm.sample(
        prompts=[prompt],
        num_steps=1399,
        temperature=2.0,
        use_tqdm=True,
    )
    # squeeze() presumably drops the size-1 batch dimension before rendering.
    rendered = convert_level_to_png(
        level_tokens.squeeze(), TILE_DIR, mario_lm.tokenizer
    )
    return rendered[0]
# Build the UI: a prompt box and a button that renders the generated level,
# plus clickable example prompts whose outputs are precomputed (cache_examples).
with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Enter your MarioGPT prompt")
    level_image = gr.Image()
    btn = gr.Button("Generate level")
    btn.click(fn=update, inputs=prompt, outputs=level_image)
    gr.Examples(
        examples=[
            "many pipes, many enemies, some blocks, high elevation",
            "little pipes, little enemies, many blocks, high elevation",
            "many pipes, some enemies",
            "no pipes, no enemies, many blocks",
        ],
        inputs=prompt,
        outputs=level_image,
        fn=update,
        cache_examples=True,
    )

# `update` reports progress through gr.Progress and runs a long 1399-step
# sampling job. Gradio only streams progress updates when queuing is enabled,
# and the queue also keeps long-running requests from timing out. This is a
# no-op on Gradio versions that already queue by default.
demo.queue()
demo.launch()