File size: 4,015 Bytes
c8c7b71
850b0e4
 
c8c7b71
850b0e4
 
 
 
 
c8c7b71
 
 
 
850b0e4
c8c7b71
850b0e4
 
 
 
c8c7b71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ec2ec6
c8c7b71
 
 
 
 
 
 
 
 
 
 
1ec2ec6
c8c7b71
1ec2ec6
 
 
850b0e4
 
 
1ec2ec6
 
850b0e4
 
c8c7b71
850b0e4
c8c7b71
 
 
 
 
 
850b0e4
 
c8c7b71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226c5b4
 
850b0e4
c8c7b71
 
 
 
 
d59d1e6
1ec2ec6
 
 
 
 
 
 
c8c7b71
 
141b1fb
d59d1e6
850b0e4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108

import gradio as gr
import torch
import uuid
from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png

import os
import subprocess

from pyngrok import ngrok

# Load the MarioGPT language model once at import time; every request reuses it.
mario_lm = MarioLM()
# Fall back to CPU so the demo still starts on hosts without CUDA
# (sampling will be much slower there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mario_lm = mario_lm.to(device)
# Directory of tile images used by convert_level_to_png to render previews.
TILE_DIR = "data/tiles"

# Serve the current working directory on port 7861 so the demo-*.html pages
# written by make_html_file can be loaded in the iframe.
# NOTE(review): http.server exposes the ENTIRE working directory, and the
# ngrok tunnel below makes it public -- anything stored in the CWD is
# downloadable by anyone who knows the tunnel URL.
subprocess.Popen(["python3", "-m", "http.server", "7861"])

# Only override ngrok's auth token when one is provided; passing None would
# fail, and a token may already be configured for the ngrok agent.
ngrok_token = os.environ.get("NGROK_TOKEN")
if ngrok_token:
    ngrok.set_auth_token(ngrok_token)
# Public HTTPS URL forwarding to the local HTTP server above.
http_tunnel = ngrok.connect(7861, bind_tls=True)

def make_html_file(generated_level):
    """Write a standalone CheerpJ HTML page that plays ``generated_level``.

    The level is rendered to text with ``view_level`` (its rows joined by
    newlines) and embedded in a JavaScript template literal. The page registers
    that text with CheerpJ as the virtual file ``/str/mylevel.txt`` and then
    runs ``/app/mario.jar``.

    Args:
        generated_level: Output of ``mario_lm.sample``; passed unchanged to
            ``view_level`` together with ``mario_lm.tokenizer``.

    Returns:
        The relative filename (``demo-<uuid>.html``) of the page, written to
        the current working directory.
    """
    level_text = "\n".join(view_level(generated_level, mario_lm.tokenizer))
    # The text is injected into a JS template literal inside a <script> tag:
    # escape backslashes first, then anything that would end the literal,
    # start an interpolation, or close the script element early.
    level_text = (
        level_text.replace("\\", "\\\\")
        .replace("`", "\\`")
        .replace("${", "\\${")
        .replace("</", "<\\/")
    )
    # uuid4 is random; uuid1 would embed this host's MAC address in a
    # publicly served filename.
    filename = f"demo-{uuid.uuid4()}.html"
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(f'''<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="utf-8">
    <title>Mario Game</title>
    <script src="https://cjrtnc.leaningtech.com/20230216/loader.js"></script>
</head>

<body>
</body>
<script>
    cheerpjInit().then(function () {{
        cheerpjAddStringFile("/str/mylevel.txt", `{level_text}`);
    }});
    cheerpjCreateDisplay(512, 500);
    cheerpjRunJar("/app/mario.jar");
</script>
</html>''')
    return filename

def generate(pipes, enemies, blocks, elevation, temperature = 2.0, level_size = 1399, prompt = ""):
    """Sample a Mario level and return a preview image plus playable HTML.

    If ``prompt`` is empty or whitespace-only, a prompt is composed from the
    four radio selections, e.g. ``"many pipes, some enemies, little blocks,
    low elevation"``.

    Args:
        pipes, enemies, blocks, elevation: Radio selections used to compose
            the prompt when ``prompt`` is empty.
        temperature: Sampling temperature forwarded to ``mario_lm.sample``;
            ``None`` (e.g. a cleared number field) falls back to 2.0.
        level_size: Number of sampling steps (``num_steps``); ``None`` falls
            back to 1399.
        prompt: Free-text prompt; takes precedence when non-empty.

    Returns:
        ``[img, gradio_html]``: the first image from ``convert_level_to_png``
        and an HTML snippet embedding the generated page via the ngrok tunnel.

    Raises:
        gr.Error: If no prompt is typed and a radio selection is missing, or
            if ``level_size`` is not positive.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        selections = {
            "pipes": pipes,
            "enemies": enemies,
            "blocks": blocks,
            "elevation": elevation,
        }
        missing = [name for name, value in selections.items() if not value]
        if missing:
            # Otherwise the prompt would read e.g. "None pipes, None enemies, ...".
            raise gr.Error(f"Type a prompt or choose a value for: {', '.join(missing)}")
        prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"

    temperature = 2.0 if temperature is None else float(temperature)
    level_size = 1399 if level_size is None else int(level_size)
    if level_size < 1:
        raise gr.Error("level_size must be a positive integer")

    print(f"Using prompt: {prompt}")
    generated_level = mario_lm.sample(
        prompts=[prompt],
        num_steps=level_size,
        temperature=temperature,
        use_tqdm=True,
    )
    filename = make_html_file(generated_level)
    img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]

    gradio_html = f'''<div style="border: 2px solid;">
        <iframe width=512 height=512 style="margin: 0 auto" src="{http_tunnel.public_url}/{filename}"></iframe>
        <p style="text-align:center">Press the arrow keys to move. Press <code>s</code> to jump and <code>a</code> to shoot flowers</p>
    </div>'''
    return [img, gradio_html]

# Gradio UI: a typed or composed prompt, advanced sampling settings, and two
# outputs -- the playable level (HTML iframe) and a rendered PNG preview.
with gr.Blocks() as demo:
    gr.Markdown('''### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models
    [[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
    ''')
    with gr.Tabs():
        with gr.TabItem("Type prompt"):
            text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
        with gr.TabItem("Compose prompt"):
            with gr.Row():
                pipes = gr.Radio(["no", "little", "some", "many"], label="pipes")
                enemies = gr.Radio(["no", "little", "some", "many"], label="enemies")
            with gr.Row():
                blocks = gr.Radio(["little", "some", "many"], label="blocks")
                elevation = gr.Radio(["low", "high"], label="elevation")

    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Number(value=2.0, label="temperature: Increase these for more diverse, but lower quality, generations")
        level_size = gr.Number(value=1399, precision=0, label="level_size")

    btn = gr.Button("Generate level")
    with gr.Row():
        with gr.Box():
            level_play = gr.HTML()
        level_image = gr.Image()
    # Input order must match generate()'s positional parameters.
    btn.click(fn=generate, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt], outputs=[level_image, level_play])
    # Each example row is (pipes, enemies, blocks, elevation, temperature).
    # Example columns are mapped positionally onto `inputs`, so `temperature`
    # must be listed there or the per-example temperatures never reach
    # generate(); every row therefore carries all five values.
    gr.Examples(
        examples=[
            ["many", "many", "some", "high", 2.0],
            ["no", "some", "many", "high", 2.0],
            ["many", "many", "little", "low", 2.0],
            ["no", "no", "many", "high", 2.4],
        ],
        inputs=[pipes, enemies, blocks, elevation, temperature],
        outputs=[level_image, level_play],
        fn=generate,
        cache_examples=True,
    )
demo.launch()