Spaces:
Runtime error
Runtime error
Commit
•
ee2e0b7
1
Parent(s):
6845a5c
Update app.py
Browse files
app.py
CHANGED
@@ -7,25 +7,24 @@ from mario_gpt.prompter import Prompter
|
|
7 |
from mario_gpt.lm import MarioLM
|
8 |
from mario_gpt.utils import view_level, convert_level_to_png
|
9 |
|
10 |
-
import
|
11 |
-
import
|
12 |
|
13 |
-
|
|
|
14 |
|
15 |
mario_lm = MarioLM()
|
16 |
device = torch.device('cuda')
|
17 |
mario_lm = mario_lm.to(device)
|
18 |
TILE_DIR = "data/tiles"
|
19 |
|
20 |
-
|
21 |
-
ngrok.set_auth_token(os.environ.get('NGROK_TOKEN'))
|
22 |
-
http_tunnel = ngrok.connect(7861,bind_tls=True)
|
23 |
|
24 |
def make_html_file(generated_level):
|
25 |
level_text = f"""{'''
|
26 |
'''.join(view_level(generated_level,mario_lm.tokenizer))}"""
|
27 |
unique_id = uuid.uuid1()
|
28 |
-
with open(f"demo-{unique_id}.html", 'w', encoding='utf-8') as f:
|
29 |
f.write(f'''<!DOCTYPE html>
|
30 |
<html lang="en">
|
31 |
|
@@ -42,7 +41,7 @@ def make_html_file(generated_level):
|
|
42 |
cheerpjAddStringFile("/str/mylevel.txt", `{level_text}`);
|
43 |
}});
|
44 |
cheerpjCreateDisplay(512, 500);
|
45 |
-
cheerpjRunJar("/app/mario.jar");
|
46 |
</script>
|
47 |
</html>''')
|
48 |
return f"demo-{unique_id}.html"
|
@@ -61,9 +60,9 @@ def generate(pipes, enemies, blocks, elevation, temperature = 2.0, level_size =
|
|
61 |
filename = make_html_file(generated_level)
|
62 |
img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
|
63 |
|
64 |
-
gradio_html = f'''<div
|
65 |
-
<iframe width=512 height=512 style="margin: 0 auto" src="
|
66 |
-
<p style="text-align:center">Press the arrow keys to move. Press <code>s</code> to jump and <code>
|
67 |
</div>'''
|
68 |
return [img, gradio_html]
|
69 |
|
@@ -72,16 +71,16 @@ with gr.Blocks() as demo:
|
|
72 |
[[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
|
73 |
''')
|
74 |
with gr.Tabs():
|
75 |
-
with gr.TabItem("Type prompt"):
|
76 |
-
text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
|
77 |
with gr.TabItem("Compose prompt"):
|
78 |
with gr.Row():
|
79 |
-
pipes = gr.Radio(["no", "little", "some", "many"], label="pipes")
|
80 |
-
enemies = gr.Radio(["no", "little", "some", "many"], label="enemies")
|
81 |
with gr.Row():
|
82 |
-
blocks = gr.Radio(["little", "some", "many"], label="blocks")
|
83 |
-
elevation = gr.Radio(["low", "high"], label="
|
84 |
-
|
|
|
|
|
85 |
with gr.Accordion(label="Advanced settings", open=False):
|
86 |
temperature = gr.Number(value=2.0, label="temperature: Increase these for more diverse, but lower quality, generations")
|
87 |
level_size = gr.Number(value=1399, precision=0, label="level_size")
|
@@ -104,4 +103,7 @@ with gr.Blocks() as demo:
|
|
104 |
fn=generate,
|
105 |
cache_examples=True,
|
106 |
)
|
107 |
-
|
|
|
|
|
|
|
|
7 |
from mario_gpt.lm import MarioLM
|
8 |
from mario_gpt.utils import view_level, convert_level_to_png
|
9 |
|
10 |
+
from fastapi import FastAPI
|
11 |
+
from fastapi.staticfiles import StaticFiles
|
12 |
|
13 |
+
import os
|
14 |
+
import uvicorn
|
15 |
|
16 |
mario_lm = MarioLM()
|
17 |
device = torch.device('cuda')
|
18 |
mario_lm = mario_lm.to(device)
|
19 |
TILE_DIR = "data/tiles"
|
20 |
|
21 |
+
app = FastAPI()
|
|
|
|
|
22 |
|
23 |
def make_html_file(generated_level):
|
24 |
level_text = f"""{'''
|
25 |
'''.join(view_level(generated_level,mario_lm.tokenizer))}"""
|
26 |
unique_id = uuid.uuid1()
|
27 |
+
with open(f"static/demo-{unique_id}.html", 'w', encoding='utf-8') as f:
|
28 |
f.write(f'''<!DOCTYPE html>
|
29 |
<html lang="en">
|
30 |
|
|
|
41 |
cheerpjAddStringFile("/str/mylevel.txt", `{level_text}`);
|
42 |
}});
|
43 |
cheerpjCreateDisplay(512, 500);
|
44 |
+
cheerpjRunJar("/app/static/mario.jar");
|
45 |
</script>
|
46 |
</html>''')
|
47 |
return f"demo-{unique_id}.html"
|
|
|
60 |
filename = make_html_file(generated_level)
|
61 |
img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
|
62 |
|
63 |
+
gradio_html = f'''<div>
|
64 |
+
<iframe width=512 height=512 style="margin: 0 auto" src="static/{filename}"></iframe>
|
65 |
+
<p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
|
66 |
</div>'''
|
67 |
return [img, gradio_html]
|
68 |
|
|
|
71 |
[[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
|
72 |
''')
|
73 |
with gr.Tabs():
|
|
|
|
|
74 |
with gr.TabItem("Compose prompt"):
|
75 |
with gr.Row():
|
76 |
+
pipes = gr.Radio(["no", "little", "some", "many"], label="How many pipes?")
|
77 |
+
enemies = gr.Radio(["no", "little", "some", "many"], label="How many enemies?")
|
78 |
with gr.Row():
|
79 |
+
blocks = gr.Radio(["little", "some", "many"], label="How many blocks?")
|
80 |
+
elevation = gr.Radio(["low", "high"], label="Elevation?")
|
81 |
+
with gr.TabItem("Type prompt"):
|
82 |
+
text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
|
83 |
+
|
84 |
with gr.Accordion(label="Advanced settings", open=False):
|
85 |
temperature = gr.Number(value=2.0, label="temperature: Increase these for more diverse, but lower quality, generations")
|
86 |
level_size = gr.Number(value=1399, precision=0, label="level_size")
|
|
|
103 |
fn=generate,
|
104 |
cache_examples=True,
|
105 |
)
|
106 |
+
|
107 |
+
app.mount("/static", StaticFiles(directory="static", html=True), name="static")
|
108 |
+
app = gr.mount_gradio_app(app, demo, "/", gradio_api_url="http://localhost:7860/")
|
109 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|