File size: 889 Bytes
850b0e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a7833f
 
850b0e4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import gradio as gr
import torch
from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png

# Instantiate the MarioGPT language model once at import time so every
# request reuses the same model object (presumably loading default
# pretrained weights -- confirm in mario_gpt.lm.MarioLM).
mario_lm = MarioLM()

# Prefer the GPU when one is available, but fall back to the CPU so the demo
# still starts on machines without CUDA; a hard-coded torch.device('cuda')
# would make the .to(device) call below fail there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mario_lm = mario_lm.to(device)

# Directory of tile images passed to convert_level_to_png when rendering
# generated levels.
TILE_DIR = "data/tiles"

def update(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate a Mario level from a text prompt and render it for display.

    Args:
        prompt: Text entered by the user in the Gradio textbox.
        progress: Gradio progress tracker. With ``track_tqdm=True`` Gradio
            can mirror tqdm progress bars created during the call (sampling
            is requested with ``use_tqdm=True``). Not referenced directly.

    Returns:
        The first element of ``convert_level_to_png``'s result for the
        generated level, which is shown by the ``gr.Image`` output.
    """
    sampled = mario_lm.sample(
        prompts=[prompt],
        num_steps=1399,
        temperature=2.0,
        use_tqdm=True,
    )
    # Drop size-1 dimensions (e.g. the single-prompt batch) before rendering.
    level_tokens = sampled.squeeze()
    rendered = convert_level_to_png(level_tokens, TILE_DIR, mario_lm.tokenizer)
    return rendered[0]

# Build the Gradio UI: a prompt textbox, a button that triggers generation,
# and an image component that displays the rendered level.
with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Enter your MarioGPT prompt")
    level_image = gr.Image()
    btn = gr.Button("Generate level")
    btn.click(fn=update, inputs=prompt, outputs=level_image)

if __name__ == "__main__":
    # update() declares a gr.Progress(track_tqdm=True) parameter. Gradio 3.x
    # only supports progress tracking for queued events, and the queue also
    # keeps long level generations from running into request timeouts.
    # Newer Gradio versions queue by default, so this call is harmless there.
    demo.queue()
    demo.launch()