Spaces:
Running
Running
Banjo Obayomi
committed on
Commit
•
0a32c0e
1
Parent(s):
ac9d471
init commit
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +54 -0
- LICENSE +21 -0
- app.py +446 -0
- mario_gpt/__init__.py +17 -0
- mario_gpt/data/tiles/N.png +0 -0
- mario_gpt/data/tiles/Y.png +0 -0
- mario_gpt/data/tiles/cannon_bottom.png +0 -0
- mario_gpt/data/tiles/cannon_top.png +0 -0
- mario_gpt/data/tiles/flying_koopa.png +0 -0
- mario_gpt/data/tiles/ki-background.png +0 -0
- mario_gpt/data/tiles/ki-door.png +0 -0
- mario_gpt/data/tiles/ki-hazard.png +0 -0
- mario_gpt/data/tiles/ki-moving-platform.png +0 -0
- mario_gpt/data/tiles/ki-passable.png +0 -0
- mario_gpt/data/tiles/ki-path.png +0 -0
- mario_gpt/data/tiles/ki-unpassable.png +0 -0
- mario_gpt/data/tiles/mm-CMM.png +0 -0
- mario_gpt/data/tiles/mm-DMM.png +0 -0
- mario_gpt/data/tiles/mm-HMM.png +0 -0
- mario_gpt/data/tiles/mm-LMM.png +0 -0
- mario_gpt/data/tiles/mm-MMM.png +0 -0
- mario_gpt/data/tiles/mm-TMM.png +0 -0
- mario_gpt/data/tiles/mma_tiles.zip +3 -0
- mario_gpt/data/tiles/plant.png +0 -0
- mario_gpt/data/tiles/smb-background.png +0 -0
- mario_gpt/data/tiles/smb-breakable.png +0 -0
- mario_gpt/data/tiles/smb-coin.png +0 -0
- mario_gpt/data/tiles/smb-enemy.png +0 -0
- mario_gpt/data/tiles/smb-path.png +0 -0
- mario_gpt/data/tiles/smb-question.png +0 -0
- mario_gpt/data/tiles/smb-tube-lower-left.png +0 -0
- mario_gpt/data/tiles/smb-tube-lower-right.png +0 -0
- mario_gpt/data/tiles/smb-tube-top-left.png +0 -0
- mario_gpt/data/tiles/smb-tube-top-right.png +0 -0
- mario_gpt/data/tiles/smb-unpassable.png +0 -0
- mario_gpt/data/tiles/smb_enemies_sheet.png +0 -0
- mario_gpt/data/tiles/tile004 (1).png +0 -0
- mario_gpt/data/tiles/tile004 (2).png +0 -0
- mario_gpt/data/tiles/tile004.png +0 -0
- mario_gpt/dataset.py +138 -0
- mario_gpt/level.py +0 -0
- mario_gpt/lm/__init__.py +44 -0
- mario_gpt/lm/base.py +91 -0
- mario_gpt/lm/bert.py +95 -0
- mario_gpt/lm/gpt.py +97 -0
- mario_gpt/prompter.py +175 -0
- mario_gpt/sampler.py +370 -0
- mario_gpt/simulator/PlayAstar.jar +0 -0
- mario_gpt/simulator/PlayLevel.jar +0 -0
- mario_gpt/simulator/__init__.py +3 -0
.gitignore
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.py[cod]
|
2 |
+
|
3 |
+
# C extensions
|
4 |
+
*.so
|
5 |
+
|
6 |
+
# Packages
|
7 |
+
*.egg
|
8 |
+
*.egg-info
|
9 |
+
dist
|
10 |
+
build
|
11 |
+
eggs
|
12 |
+
parts
|
13 |
+
bin
|
14 |
+
var
|
15 |
+
sdist
|
16 |
+
develop-eggs
|
17 |
+
.installed.cfg
|
18 |
+
lib
|
19 |
+
lib64
|
20 |
+
__pycache__
|
21 |
+
|
22 |
+
# Installer logs
|
23 |
+
pip-log.txt
|
24 |
+
|
25 |
+
# Unit test / coverage reports
|
26 |
+
.coverage
|
27 |
+
.tox
|
28 |
+
nosetests.xml
|
29 |
+
|
30 |
+
# Translations
|
31 |
+
*.mo
|
32 |
+
|
33 |
+
# Mr Developer
|
34 |
+
.mr.developer.cfg
|
35 |
+
.project
|
36 |
+
.pydevproject
|
37 |
+
test.json
|
38 |
+
*.pickle
|
39 |
+
venv
|
40 |
+
.idea
|
41 |
+
*.vscode/
|
42 |
+
.DS_Store
|
43 |
+
|
44 |
+
#notebooks
|
45 |
+
*/**/.ipynb_checkpoints/*
|
46 |
+
|
47 |
+
# logs
|
48 |
+
*/**/checkpoints/*
|
49 |
+
*/**/mlruns/*
|
50 |
+
*/**/tensorboard_logs/*
|
51 |
+
*/**/wandb/*
|
52 |
+
checkpoints/*
|
53 |
+
wandb/*
|
54 |
+
mlruns/*
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The MIT License (MIT)
|
2 |
+
|
3 |
+
Copyright (c) 2023 Shyam Sudhakaran
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
app.py
ADDED
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import uuid
|
3 |
+
from mario_gpt.lm import MarioLM
|
4 |
+
from mario_gpt.utils import convert_level_to_png
|
5 |
+
|
6 |
+
from fastapi import FastAPI
|
7 |
+
from fastapi.staticfiles import StaticFiles
|
8 |
+
|
9 |
+
import uvicorn
|
10 |
+
import boto3
|
11 |
+
import json
|
12 |
+
|
13 |
+
bedrock_runtime = boto3.client(
|
14 |
+
service_name="bedrock-runtime",
|
15 |
+
region_name="us-east-1",
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
def get_raw_text(level_data):
    """Join level rows into one newline-terminated string (one row per line)."""
    return "".join(row + "\n" for row in level_data)
24 |
+
|
25 |
+
|
26 |
+
def combine_levels(level_arrays):
    """Concatenate several level sections horizontally, row by row.

    Each element of *level_arrays* is a list of row strings; the number of
    rows is taken from the first level.
    """
    row_count = len(level_arrays[0])
    return [
        "".join(section[i] for section in level_arrays) for i in range(row_count)
    ]
38 |
+
|
39 |
+
|
40 |
+
def write_level_to_file(level_data, file_name):
    """Write the level rows to *file_name*, one row per line."""
    with open(file_name, "w") as out:
        out.writelines(row + "\n" for row in level_data)
44 |
+
|
45 |
+
|
46 |
+
def clean_level_data(input_string):
    """Extract and normalize the 2D level array from raw LLM output.

    Locates the bracketed list inside *input_string*, splits it on commas,
    strips surrounding whitespace/quotes from every row, and forces each row
    to exactly 50 characters — padding with '-' (sky) or truncating.
    """
    width = 50

    # Boundaries of the level array within the model's response text.
    first = input_string.find("[")
    last = input_string.rfind("]")
    raw_rows = input_string[first + 1 : last].split(",")

    cleaned = []
    for raw in raw_rows:
        row = raw.strip().strip("'")
        # Normalize every row to the fixed level width.
        row = row[:width] if len(row) > width else row.ljust(width, "-")
        cleaned.append(row)
    return cleaned
72 |
+
|
73 |
+
|
74 |
+
def _invoke_llama(model_id, system_prompt, prompt):
    """Invoke a Bedrock Llama 3 model and return its generated text.

    The Llama invoke API takes a single prompt string, so the system prompt
    is simply prepended to the user prompt.
    """
    body = json.dumps(
        {
            "prompt": system_prompt + prompt,
            "max_gen_len": 2048,
            "top_p": 0.9,
            "temperature": 0.7,
        }
    )

    response = bedrock_runtime.invoke_model(
        body=body,
        modelId=model_id,
        accept="application/json",
        contentType="application/json",
    )
    response_body = json.loads(response.get("body").read())
    return response_body["generation"].strip()


def call_llama3_70b(system_prompt, prompt):
    """Generate text with Llama 3 70B Instruct via Amazon Bedrock."""
    return _invoke_llama("meta.llama3-70b-instruct-v1:0", system_prompt, prompt)


def call_llama3_8b(system_prompt, prompt):
    """Generate text with Llama 3 8B Instruct via Amazon Bedrock."""
    return _invoke_llama("meta.llama3-8b-instruct-v1:0", system_prompt, prompt)
118 |
+
|
119 |
+
|
120 |
+
# def call_claude_3_opus(system_prompt, prompt):
|
121 |
+
|
122 |
+
# prompt_config = {
|
123 |
+
# "anthropic_version": "bedrock-2023-05-31",
|
124 |
+
# "max_tokens": 4096,
|
125 |
+
# "system": system_prompt,
|
126 |
+
# "messages": [
|
127 |
+
# {
|
128 |
+
# "role": "user",
|
129 |
+
# "content": [
|
130 |
+
# {"type": "text", "text": prompt},
|
131 |
+
# ],
|
132 |
+
# }
|
133 |
+
# ],
|
134 |
+
# }
|
135 |
+
|
136 |
+
# body = json.dumps(prompt_config)
|
137 |
+
|
138 |
+
# modelId = "anthropic.claude-3-opus-20240229-v1:0"
|
139 |
+
# accept = "application/json"
|
140 |
+
# contentType = "application/json"
|
141 |
+
|
142 |
+
# response = bedrock_runtime.invoke_model(
|
143 |
+
# body=body, modelId=modelId, accept=accept, contentType=contentType
|
144 |
+
# )
|
145 |
+
# response_body = json.loads(response.get("body").read())
|
146 |
+
|
147 |
+
# results = response_body.get("content")[0].get("text")
|
148 |
+
# return results
|
149 |
+
|
150 |
+
|
151 |
+
# Call Claude model
|
152 |
+
def _invoke_claude(model_id, system_prompt, prompt):
    """Invoke a Bedrock Claude 3 model (Messages API) and return its text reply."""
    body = json.dumps(
        {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 4096,
            "system": system_prompt,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                    ],
                }
            ],
        }
    )

    response = bedrock_runtime.invoke_model(
        body=body,
        modelId=model_id,
        accept="application/json",
        contentType="application/json",
    )
    response_body = json.loads(response.get("body").read())
    return response_body.get("content")[0].get("text")


def call_claude_3_sonnet(system_prompt, prompt):
    """Generate text with Claude 3 Sonnet via Amazon Bedrock."""
    return _invoke_claude(
        "anthropic.claude-3-sonnet-20240229-v1:0", system_prompt, prompt
    )


def call_claude_3_haiku(system_prompt, prompt):
    """Generate text with Claude 3 Haiku via Amazon Bedrock."""
    return _invoke_claude(
        "anthropic.claude-3-haiku-20240307-v1:0", system_prompt, prompt
    )
212 |
+
|
213 |
+
|
214 |
+
system_prompt_text = """
|
215 |
+
As an esteemed level designer renowned for creating some of the top 100 levels in Super Mario Maker, you are tasked with crafting a playable section for the original Super Mario on NES. Your extensive experience and creativity are key to designing levels that are not only challenging but also immensely enjoyable. Use the following symbols to represent different game elements, ensuring each level is a masterpiece of design:
|
216 |
+
|
217 |
+
<symbols>
|
218 |
+
- = "Sky"
|
219 |
+
X = "Unbreakable Block"
|
220 |
+
E = "Enemy"
|
221 |
+
o = "Coin"
|
222 |
+
S = "Breakable Block"
|
223 |
+
? = "Question Block"
|
224 |
+
[] = "Pipe"
|
225 |
+
<> = "End of Pipe"
|
226 |
+
</symbols>
|
227 |
+
|
228 |
+
|
229 |
+
Adhere to these level layout specifications:
|
230 |
+
|
231 |
+
<level guidelines>
|
232 |
+
Pipes should be vertical and follow this format:
|
233 |
+
<>
|
234 |
+
[]
|
235 |
+
[]
|
236 |
+
|
237 |
+
Ensure there is a clear and navigable path that Mario can follow from the start to the end of the level. This path may involve jumping on blocks or pipes, running on blocks.
|
238 |
+
|
239 |
+
The path should be continuous and not lead Mario into any dead ends or impossible situations.
|
240 |
+
|
241 |
+
Place unbreakable blocks (X) or other platform elements strategically to create a solid foundation for Mario to walk on. Avoid creating large gaps or sections without any ground or platforms, as Mario needs a surface to stand on.
|
242 |
+
|
243 |
+
Adjust the complexity and elements based on the specific level request, ensuring that Mario can always complete the level successfully by following the designated path.
|
244 |
+
</level guidelines>
|
245 |
+
|
246 |
+
|
247 |
+
|
248 |
+
For example a prompt that asks for a level with one pipe and blocks with 2 Goombas.
|
249 |
+
|
250 |
+
Here is an example output:
|
251 |
+
<example>
|
252 |
+
['--------------------------------------------------',
|
253 |
+
'--------------------------------------------------',
|
254 |
+
'--------------------------------------------------',
|
255 |
+
'--------------------------------------------------',
|
256 |
+
'-------------------------------------------------o',
|
257 |
+
'--------XSSSSS---------------------------------SSS',
|
258 |
+
'--------X-----------------------------------------',
|
259 |
+
'--------X-----------------------------------------',
|
260 |
+
'-------EX--E-X---------------xxxx-?-----------xxxx',
|
261 |
+
'--------XSS?SX---QQ?QQ------xx<>-x-----------xx--?',
|
262 |
+
'---------------------------xx-[]--x---------xx----',
|
263 |
+
'--------------------------xx--[]---x-------xx-----',
|
264 |
+
'xxxxxxxxxxxxxxxxxxxxxxxxxxx---[]----xxxxxxxx------',
|
265 |
+
'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX---XXX']
|
266 |
+
</example>
|
267 |
+
|
268 |
+
|
269 |
+
Generate the level section as a 2D array, where each row is represented as a string of characters. The level section should be 14 rows tall and 50 columns wide. Only return the 2D array of characters.
|
270 |
+
|
271 |
+
Remember, your creations should challenge players but remain fair. Use your expertise to weave together obstacles and rewards, encouraging exploration and skillful play. Always ensure that Mario has a clear and navigable route to finish the level, and provide ample block tiles for Mario to walk on.
|
272 |
+
"""
|
273 |
+
|
274 |
+
|
275 |
+
mario_lm = MarioLM()
|
276 |
+
# device = torch.device('cuda')
|
277 |
+
# mario_lm = mario_lm.to(device)
|
278 |
+
TILE_DIR = "mario_gpt/data/tiles"
|
279 |
+
|
280 |
+
app = FastAPI()
|
281 |
+
|
282 |
+
|
283 |
+
def make_html_file(generated_level):
    """Write a CheerpJ HTML page that feeds *generated_level* to the Mario jar.

    The page is saved under static/ with a unique name, and that file name
    (without the directory) is returned so the caller can embed it.
    """
    unique_id = uuid.uuid1()
    file_name = f"demo-{unique_id}.html"
    page = f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>Mario Game</title>
<script src="https://cjrtnc.leaningtech.com/20230216/loader.js"></script>
</head>
<body>
</body>
<script>
cheerpjInit().then(function () {{
cheerpjAddStringFile("/str/mylevel.txt", `{generated_level}`);
}});
cheerpjCreateDisplay(612, 600);
cheerpjRunJar("/app/static/mario.jar");
</script>
</html>"""
    with open(f"static/{file_name}", "w", encoding="utf-8") as f:
        f.write(page)
    return file_name
307 |
+
|
308 |
+
|
309 |
+
def generate(model, prompt, system_prompt=system_prompt_text):
    """Generate a Mario level with the selected Bedrock model.

    Returns ``[img, gradio_html]``: a PNG rendering of the level and an HTML
    snippet embedding a playable CheerpJ demo, matching the Gradio outputs
    ``[level_image, level_play]``.
    """
    print(f"Using prompt: {prompt}")

    # An emptied-out system prompt box falls back to the default prompt.
    if system_prompt == "":
        system_prompt = system_prompt_text

    # Dispatch table keeps model names and callers in one place.
    callers = {
        "Claude Sonnet": call_claude_3_sonnet,
        "Claude Haiku": call_claude_3_haiku,
        "Llama3 70B": call_llama3_70b,
        "Llama3 8B": call_llama3_8b,
    }
    if model not in callers:
        raise ValueError("Invalid model")
    level = callers[model](system_prompt, prompt)

    cleaned_level = clean_level_data(level)
    raw_level_text = get_raw_text(cleaned_level)

    # BUG FIX: the iframe previously did not interpolate the generated file
    # name, so the playable demo pointed at a nonexistent page.
    filename = make_html_file(raw_level_text)
    img = convert_level_to_png(cleaned_level, mario_lm.tokenizer)[0]

    gradio_html = f"""<div>
    <iframe width=612 height=612 style="margin: 0 auto" src="static/{filename}"></iframe>
    <p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
</div>"""
    return [img, gradio_html]
356 |
+
|
357 |
+
|
358 |
+
# ---------------------------------------------------------------------------
# Gradio UI: prompt settings, model picker, and the generated-level outputs.
# The Blocks app is mounted onto the FastAPI app alongside /static below.
# ---------------------------------------------------------------------------
with gr.Blocks().queue() as demo:
    gr.Markdown(
        """### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models - Amazon Bedrock Edition
[[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981), [[Amazon Bedrock](https://docs.aws.amazon.com/bedrock/latest/userguide/service_code_examples.html?trk=2403b700-9ee9-49e8-aed8-411dea5cf5ae&sc_channel=el)]
"""
    )
    with gr.Tabs():
        with gr.TabItem("Prompt Settings"):
            with gr.Accordion(label="System Prompt", open=False):
                system_prompt = gr.TextArea(
                    value=system_prompt_text,
                    label="Enter your MarioGPT System prompt. ex: 'As an esteemed level designer renowned for creating some of the top 100 levels in Super Mario Maker...'",
                )

            text_prompt = gr.Textbox(
                value="Generate a level with a few pipes, many coins. make sure there are only 10 enemies. Make sure there is a ground path Mario can walk on",
                label="Enter your MarioGPT prompt. ex: 'Generate a level with a few pipes, many coins. make sure there are only 10 enemies. Make sure there is a ground path Mario can walk on'",
            )

            model = gr.Radio(
                [
                    "Claude Sonnet",
                    "Claude Haiku",
                    "Llama3 70B",
                    "Llama3 8B",
                ],
                label="Select Model",
                value="Claude Sonnet",
            )

    btn = gr.Button("Generate level")
    with gr.Row():
        with gr.Group():
            level_play = gr.HTML()
            level_image = gr.Image()
    btn.click(
        fn=generate,
        inputs=[
            model,
            text_prompt,
            system_prompt,
        ],
        outputs=[level_image, level_play],
    )
    gr.Examples(
        examples=[
            [
                "Claude Sonnet",
                "Generate a level with a few pipes, many coins. make sure there are only 10 enemies. Make sure there is a ground path Mario can walk on",
            ],
            [
                "Claude Sonnet",
                "Design a level with blocks arranged in a pyramid-like shape, with coins scattered around the base and goombas guarding the top. Have a pipe at the top.",
            ],
            [
                "Claude Sonnet",
                "Make a simple level that has no enemies, but lots and lots of coins. Lots of blocks for mario to walk on.",
            ],
        ],
        # BUG FIX: inputs previously listed system_prompt_text (a plain string,
        # not a component) and had three entries while each example supplies
        # only two values. With two inputs, generate() falls back to its
        # default system prompt.
        inputs=[model, text_prompt],
        outputs=[level_image, level_play],
        fn=generate,
        cache_examples=True,
    )

app.mount("/static", StaticFiles(directory="static", html=True), name="static")
app = gr.mount_gradio_app(app, demo, "/")
uvicorn.run(app, host="0.0.0.0", port=7860)
|
mario_gpt/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mario_gpt.dataset import MarioDataset
|
2 |
+
from mario_gpt.lm import MarioBert, MarioGPT, MarioLM
|
3 |
+
from mario_gpt.prompter import Prompter
|
4 |
+
from mario_gpt.sampler import GPTSampler, SampleOutput
|
5 |
+
from mario_gpt.trainer import MarioGPTTrainer, TrainingConfig
|
6 |
+
|
7 |
+
__all__ = [
|
8 |
+
"Prompter",
|
9 |
+
"MarioDataset",
|
10 |
+
"MarioBert",
|
11 |
+
"MarioGPT",
|
12 |
+
"MarioLM",
|
13 |
+
"SampleOutput",
|
14 |
+
"GPTSampler",
|
15 |
+
"TrainingConfig",
|
16 |
+
"MarioGPTTrainer",
|
17 |
+
]
|
mario_gpt/data/tiles/N.png
ADDED
mario_gpt/data/tiles/Y.png
ADDED
mario_gpt/data/tiles/cannon_bottom.png
ADDED
mario_gpt/data/tiles/cannon_top.png
ADDED
mario_gpt/data/tiles/flying_koopa.png
ADDED
mario_gpt/data/tiles/ki-background.png
ADDED
mario_gpt/data/tiles/ki-door.png
ADDED
mario_gpt/data/tiles/ki-hazard.png
ADDED
mario_gpt/data/tiles/ki-moving-platform.png
ADDED
mario_gpt/data/tiles/ki-passable.png
ADDED
mario_gpt/data/tiles/ki-path.png
ADDED
mario_gpt/data/tiles/ki-unpassable.png
ADDED
mario_gpt/data/tiles/mm-CMM.png
ADDED
mario_gpt/data/tiles/mm-DMM.png
ADDED
mario_gpt/data/tiles/mm-HMM.png
ADDED
mario_gpt/data/tiles/mm-LMM.png
ADDED
mario_gpt/data/tiles/mm-MMM.png
ADDED
mario_gpt/data/tiles/mm-TMM.png
ADDED
mario_gpt/data/tiles/mma_tiles.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f6d58bb3228bcd3c653c4a58b69044588ffd6e5e4c946a860497a39d84eb60b8
|
3 |
+
size 6586
|
mario_gpt/data/tiles/plant.png
ADDED
mario_gpt/data/tiles/smb-background.png
ADDED
mario_gpt/data/tiles/smb-breakable.png
ADDED
mario_gpt/data/tiles/smb-coin.png
ADDED
mario_gpt/data/tiles/smb-enemy.png
ADDED
mario_gpt/data/tiles/smb-path.png
ADDED
mario_gpt/data/tiles/smb-question.png
ADDED
mario_gpt/data/tiles/smb-tube-lower-left.png
ADDED
mario_gpt/data/tiles/smb-tube-lower-right.png
ADDED
mario_gpt/data/tiles/smb-tube-top-left.png
ADDED
mario_gpt/data/tiles/smb-tube-top-right.png
ADDED
mario_gpt/data/tiles/smb-unpassable.png
ADDED
mario_gpt/data/tiles/smb_enemies_sheet.png
ADDED
mario_gpt/data/tiles/tile004 (1).png
ADDED
mario_gpt/data/tiles/tile004 (2).png
ADDED
mario_gpt/data/tiles/tile004.png
ADDED
mario_gpt/dataset.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import List, Optional
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
9 |
+
|
10 |
+
from mario_gpt.level import FULL_LEVEL_STR_WITH_PATHS
|
11 |
+
|
12 |
+
DEFAULT_MODEL = "distilgpt2"
|
13 |
+
|
14 |
+
|
15 |
+
def split_given_size(a, size):
    """Split *a* into consecutive chunks of *size*; the last may be shorter."""
    boundaries = np.arange(size, len(a), size)
    return np.split(a, boundaries)
17 |
+
|
18 |
+
|
19 |
+
def flip_and_transpose(arr: np.array, flip_first: bool = False):
    """Flip *arr* along its last axis and transpose it.

    With ``flip_first`` the flip happens before the transpose; otherwise the
    array is transposed first. Arrays whose last dimension is 1 are returned
    unchanged.
    """
    if arr.shape[-1] <= 1:
        return arr
    if flip_first:
        return np.flip(arr, -1).transpose()
    return np.flip(arr.transpose(), -1)
25 |
+
|
26 |
+
|
27 |
+
def join_list_of_list(str_lists):
    """Concatenate each inner sequence of characters into a single string."""
    return list(map("".join, str_lists))
29 |
+
|
30 |
+
|
31 |
+
def characterize(str_lists):
    """Explode each string into its list of individual characters."""
    return list(map(list, str_lists))
33 |
+
|
34 |
+
|
35 |
+
class MarioDataset(Dataset):
    """Character-level dataset of Mario levels for language-model training.

    The level string is transposed column-wise (via ``flip_and_transpose``)
    so the model reads the level one column at a time, then tokenized into a
    single long sequence from which fixed-length context windows are drawn.
    """

    def __init__(
        self,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        level_string: Optional[str] = None,
        context_len: int = 700,
        height: int = 14,
        remove_start_end_tokens: bool = False,
        sample_all_indices: bool = False,
    ):
        if level_string is None:
            print(
                "No level string specified, using default string FULL_LEVEL_STR_WITH_PATHS..."
            )
            level_string = FULL_LEVEL_STR_WITH_PATHS
        elif ".txt" in level_string:
            # Treat the argument as a file path and load the level from disk.
            with open(level_string, "r") as file:
                level_string = file.read()

        self.character_set = set(level_string)
        if "\n" in self.character_set:
            self.character_set.remove("\n")
        self.vocab_size = len(self.character_set)
        self.sample_all_indices = sample_all_indices

        def get_training_corpus():
            yield list(level_string)

        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)

        self.tokenizer = tokenizer
        # Retrain the tokenizer on the level alphabet so each tile character
        # maps to its own token.
        if getattr(tokenizer, "train_new_from_iterator", None) is not None:
            self.tokenizer = self.tokenizer.train_new_from_iterator(
                get_training_corpus(), 52000
            )
        elif getattr(tokenizer, "train_from_iterator", None) is not None:
            self.tokenizer = PreTrainedTokenizerFast(tokenizer_object=self.tokenizer)
            self.tokenizer = self.tokenizer.train_new_from_iterator(
                get_training_corpus(), self.vocab_size
            )
        self.context_len = context_len
        self.height = height

        # BUG FIX: the encoding was previously a local variable, but __str__
        # reads self.x — calling str() raised AttributeError. Store it.
        self.x, self.str_arr = self.convert_level_to_tensor(level_string.split("\n"))
        self.input_ids = self.x["input_ids"].squeeze()
        self.attention_masks = self.x["attention_mask"].squeeze()
        if remove_start_end_tokens:
            self.input_ids = self.input_ids[1:-1]
            self.attention_masks = self.attention_masks[1:-1]

        self.indices = self.generate_indices()

        self.unique_tokens, self.unique_counts = self.input_ids.unique(
            return_counts=True
        )
        # Inverse-frequency weights over the token vocabulary.
        self.weighted_unique_counts = (
            1.0 / self.unique_counts / torch.sum(self.unique_counts)
        )

        # Map each decoded tile character back to its integer token id.
        self.token_dict = {}
        string_tokens = list(self.tokenizer.decode(self.unique_tokens))
        for int_token, string_token in zip(self.unique_tokens, string_tokens):
            self.token_dict[string_token] = int_token

    def convert_level_to_tensor(self, level: List[str]):
        """Tokenize a level (list of row strings) into tensors plus the
        column-major flattened string."""
        str_arr = flip_and_transpose(np.array(characterize(level)))
        str_arr = "".join(join_list_of_list(str_arr))

        x = self.tokenizer(str_arr, return_tensors="pt")
        return x, str_arr

    def __len__(self):
        return self.indices.shape[0]

    def __getitem__(self, idx):
        if isinstance(idx, int):
            indices = self.indices[idx]
        else:
            indices = torch.stack([self.indices[i] for i in idx])
        return self.input_ids[indices], self.attention_masks[indices]

    def generate_indices(self):
        """Precompute all context windows — one per level column, or one per
        position when sample_all_indices is set."""
        out = []
        for idx in range(self.input_ids.shape[0] - self.context_len):
            if idx % self.height == 0 or self.sample_all_indices:
                arange = torch.arange(idx, idx + self.context_len)
                out.append(arange)
        return torch.stack(out)

    def sample_indices(self, batch_size):
        """Sample *batch_size* random context windows of token positions."""
        out = []
        for _ in range(batch_size):
            start_idx = np.random.randint(0, self.__len__() - self.context_len)
            indices = torch.arange(start_idx, start_idx + self.context_len)
            out.append(indices)
        return torch.stack(out)

    def __str__(self):
        # Relies on self.x being stored by __init__ (see the bug fix above).
        str_list = characterize(self.tokenizer.batch_decode(self.x["input_ids"]))
        string = "\n".join(
            join_list_of_list(flip_and_transpose(np.array(str_list), True))
        )
        return string
|
mario_gpt/level.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
mario_gpt/lm/__init__.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union
|
2 |
+
|
3 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer
|
4 |
+
|
5 |
+
# lm stuff
|
6 |
+
from mario_gpt.lm.base import BaseMarioLM
|
7 |
+
from mario_gpt.lm.bert import MarioBert
|
8 |
+
from mario_gpt.lm.gpt import MarioGPT
|
9 |
+
from mario_gpt.prompter import Prompter
|
10 |
+
|
11 |
+
|
12 |
+
def MarioLM(
    lm: Optional[PreTrainedModel] = None,
    tokenizer: Optional[PreTrainedTokenizer] = None,
    context_len: int = 700,
    prompter: Optional[Prompter] = None,
    mask_proportion: float = 0.15,
    mask_model: bool = False,
    lm_path: Optional[str] = None,
    tokenizer_path: Optional[str] = None,
    **kwargs
) -> Union[MarioGPT, MarioBert]:
    """Factory for Mario level language models.

    Returns a :class:`MarioBert` (masked LM, used for inpainting) when
    ``mask_model`` is True, otherwise a :class:`MarioGPT` (autoregressive
    generator). ``prompter`` applies only to the GPT variant and
    ``mask_proportion`` only to the BERT variant; extra ``kwargs`` are
    forwarded to the chosen constructor.
    """
    shared = dict(
        lm=lm,
        tokenizer=tokenizer,
        context_len=context_len,
        lm_path=lm_path,
        tokenizer_path=tokenizer_path,
    )
    if mask_model:
        return MarioBert(mask_proportion=mask_proportion, **shared, **kwargs)
    return MarioGPT(prompter=prompter, **shared, **kwargs)
|
42 |
+
|
43 |
+
|
44 |
+
__all__ = ["BaseMarioLM", "MarioGPT", "MarioBert", "MarioLM"]
|
mario_gpt/lm/base.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
import os
|
3 |
+
from typing import Any, Dict, Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer
|
7 |
+
|
8 |
+
|
9 |
+
class BaseMarioLM(metaclass=abc.ABCMeta):
    """Common scaffolding for Mario level language models.

    Owns a HuggingFace model (``self.lm``) and tokenizer
    (``self.tokenizer``) and resolves them from, in order of precedence:
    explicit objects, explicit paths, or the subclass's pretrained
    defaults. Subclasses supply the concrete loaders via
    ``load_pretrained_lm`` / ``load_pretrained_tokenizer``.
    """

    # Hub paths for the fine-tuned model/tokenizer; subclasses override.
    PRETRAINED_LM_PATH = ""
    PRETRAINED_TOKENIZER_PATH = ""

    # Hub paths for the untuned base architecture; subclasses override.
    BASE_LM_PATH = ""
    BASE_TOKENIZER_PATH = ""

    def __init__(
        self,
        lm: Optional[PreTrainedModel] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        context_len: int = 700,
        lm_path: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
        lm_kwargs: Optional[Dict[str, Any]] = None,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Load (or adopt) the model and tokenizer and record the context length.

        ``lm_kwargs`` / ``tokenizer_kwargs`` previously defaulted to shared
        mutable ``{}`` dicts (a classic Python pitfall); ``None`` is now the
        default and is treated as an empty dict, while explicitly passed
        dicts behave exactly as before.
        """
        self.load_pretrained(
            lm_path,
            tokenizer_path,
            lm,
            tokenizer,
            {} if lm_kwargs is None else lm_kwargs,
            {} if tokenizer_kwargs is None else tokenizer_kwargs,
        )
        self.context_len = context_len

    def train(self):
        """Put the underlying language model in training mode."""
        self.lm.train()

    def eval(self):
        """Put the underlying language model in eval (inference) mode."""
        self.lm.eval()

    @property
    def device(self):
        """Device the underlying model's parameters currently live on."""
        return self.lm.device

    def to(self, device: torch.device):
        """Move the model to ``device``; returns ``self`` for chaining."""
        self.lm = self.lm.to(device)
        return self

    def save_model(self, checkpoint_path: str, it: int):
        """Save model weights under ``<checkpoint_path>/iteration_<it>``."""
        self.lm.save_pretrained(os.path.join(checkpoint_path, f"iteration_{it}"))

    @abc.abstractmethod
    def load_pretrained_lm(
        self, path: str, lm_kwargs: Dict[str, Any]
    ) -> PreTrainedModel:
        """
        Model to be used in level tile prediction
        """

    @abc.abstractmethod
    def load_pretrained_tokenizer(
        self, path: str, tokenizer_kwargs: Dict[str, Any]
    ) -> PreTrainedTokenizer:
        """
        Tokenizer to be used to read / decode levels
        """

    def load_pretrained(
        self,
        lm_path: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
        lm: Optional[PreTrainedModel] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        lm_kwargs: Optional[Dict[str, Any]] = None,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Resolve and assign ``self.lm`` and ``self.tokenizer``.

        Explicit objects win over paths; a missing path falls back to the
        subclass's pretrained default.
        """
        lm_kwargs = {} if lm_kwargs is None else lm_kwargs
        tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
        self.lm = lm
        self.tokenizer = tokenizer

        if lm is None:
            if lm_path is None:
                lm_path = self.PRETRAINED_LM_PATH

            print(f"Using {lm_path} lm")
            self.lm = self.load_pretrained_lm(lm_path, lm_kwargs)

        if tokenizer is None:
            if tokenizer_path is None:
                # BUG FIX: previously fell back to PRETRAINED_LM_PATH, ignoring
                # PRETRAINED_TOKENIZER_PATH entirely. Behavior is unchanged for
                # the shipped subclasses (both constants are equal there), but
                # the fallback now honors the documented constant.
                tokenizer_path = self.PRETRAINED_TOKENIZER_PATH

            print(f"Using {tokenizer_path} tokenizer")
            self.tokenizer = self.load_pretrained_tokenizer(
                tokenizer_path, tokenizer_kwargs
            )
|
mario_gpt/lm/bert.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import Any, Dict, Optional
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from transformers import (
|
8 |
+
AutoModelForMaskedLM,
|
9 |
+
AutoTokenizer,
|
10 |
+
PreTrainedModel,
|
11 |
+
PreTrainedTokenizer,
|
12 |
+
RobertaModel,
|
13 |
+
RobertaTokenizer,
|
14 |
+
)
|
15 |
+
|
16 |
+
from mario_gpt.lm.base import BaseMarioLM
|
17 |
+
|
18 |
+
PRETRAINED_MODEL_PATH = "shyamsn97/MarioBert-448-inpaint-context-length"
|
19 |
+
|
20 |
+
|
21 |
+
class MarioBert(BaseMarioLM):
    """Masked language model (RoBERTa-style) over Mario level tokens.

    Used for level *inpainting*: a contiguous span of the level is replaced
    with ``<mask>`` tokens and the model predicts the hidden tiles.
    """

    PRETRAINED_LM_PATH = PRETRAINED_MODEL_PATH
    PRETRAINED_TOKENIZER_PATH = PRETRAINED_MODEL_PATH

    BASE_LM_PATH = "distilroberta-base"
    BASE_TOKENIZER_PATH = "distilroberta-base"

    def __init__(
        self,
        lm: Optional[PreTrainedModel] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        context_len: int = 448,
        mask_proportion: float = 0.16,
        lm_path: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
        lm_kwargs: Dict[str, Any] = {},
        tokenizer_kwargs: Dict[str, Any] = {},
    ):
        super().__init__(
            lm,
            tokenizer,
            context_len,
            lm_path,
            tokenizer_path,
            lm_kwargs,
            tokenizer_kwargs,
        )
        self.mask_proportion = mask_proportion
        # Absolute number of tokens to mask, derived from the context length.
        self.mask_portion = int(self.context_len * self.mask_proportion)

    def sample_mask(self, input_ids):
        """Mask one random contiguous span per batch row and return the result.

        Candidate start positions are restricted to multiples of 14 — one
        Mario level column per 14 tokens — so masked spans are column-aligned.
        """
        batch_size = input_ids.shape[0]
        seq_len = input_ids.shape[-1]
        mask_portion = self.mask_portion
        # Column-aligned start candidates (every 14th position).
        sampled_start_idx = [i for i in range(seq_len - mask_portion) if i % 14 == 0]
        sampled_start_idx = np.random.choice(sampled_start_idx, batch_size)
        sampled_masks = []
        for idx in sampled_start_idx:
            mask = torch.arange(idx, idx + mask_portion)
            sampled_masks.append(mask)
        sampled_mask_indices = torch.stack(sampled_masks)
        return self.apply_mask(input_ids, sampled_mask_indices)

    def generate_mask(self, mask_len: int, batch_size: int = 1):
        """Return a ``(batch_size, mask_len)`` tensor filled with the ``<mask>`` id.

        Index ``[1]`` skips the tokenizer's start-of-sequence token so the
        actual mask token id is picked up.
        """
        mask_token = self.tokenizer("<mask>").input_ids[1]
        ones = torch.ones((batch_size, mask_len))
        return ones * mask_token

    def apply_mask(self, level, masked_indices, mask=None):
        """Return a copy of ``level`` with ``masked_indices`` overwritten by ``mask``.

        NOTE(review): ``masked_indices`` appears to be ``(batch, mask_len)``
        while the write is ``masked_level[:, masked_indices] = mask`` — with a
        2-D index this writes the union of all rows' spans into every batch
        row, not a per-row span. Confirm intent before relying on batched use
        with differing spans.
        """
        if len(level.shape) == 1:
            level = level.unsqueeze(0)
        batch_size = level.shape[0]
        mask_len = masked_indices.shape[-1]
        if mask is None:
            mask = self.generate_mask(mask_len, batch_size)
        mask = mask.long().to(level.device)
        # Multiply by ones to copy the tensor before the in-place assignment.
        masked_level = level * torch.ones_like(level).to(level.device)
        masked_level[:, masked_indices] = mask
        return masked_level

    def generate_seed(self, length: int, batch_size: Optional[int] = None):
        """Return a seed of ``length`` repetitions of the 'X' (ground) tile token."""
        seed = self.tokenizer("X", return_tensors="pt").input_ids.squeeze()[
            1:-1
        ]  # remove start and end tokens
        if batch_size is None:
            return seed.repeat(length)
        return seed.view(1, 1).repeat(batch_size, length)

    def load_pretrained_lm(self, path: str, lm_kwargs: Dict[str, Any]) -> RobertaModel:
        """Load the masked-LM head model from a local path or the HF hub."""
        return AutoModelForMaskedLM.from_pretrained(path, **lm_kwargs)

    def load_pretrained_tokenizer(
        self, path: str, tokenizer_kwargs: Dict[str, Any]
    ) -> RobertaTokenizer:
        """Load the tokenizer paired with the masked LM."""
        return AutoTokenizer.from_pretrained(path, **tokenizer_kwargs)
|
mario_gpt/lm/gpt.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import Any, Dict, List, Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers import (
|
7 |
+
AutoConfig,
|
8 |
+
AutoModelWithLMHead,
|
9 |
+
AutoTokenizer,
|
10 |
+
GPT2Model,
|
11 |
+
GPT2Tokenizer,
|
12 |
+
PreTrainedModel,
|
13 |
+
PreTrainedTokenizer,
|
14 |
+
)
|
15 |
+
|
16 |
+
from mario_gpt.lm.base import BaseMarioLM
|
17 |
+
from mario_gpt.prompter import Prompter
|
18 |
+
from mario_gpt.sampler import GPTSampler, SampleOutput
|
19 |
+
|
20 |
+
PRETRAINED_MODEL_PATH = "shyamsn97/Mario-GPT2-700-context-length"
|
21 |
+
|
22 |
+
|
23 |
+
class MarioGPT(BaseMarioLM):
    """Autoregressive (GPT-2 style) Mario level generator.

    Generates levels token-by-token, conditioned on a natural-language
    prompt embedded by a :class:`Prompter` and fed in through
    cross-attention (``add_cross_attention=True``).
    """

    PRETRAINED_LM_PATH = PRETRAINED_MODEL_PATH
    PRETRAINED_TOKENIZER_PATH = PRETRAINED_MODEL_PATH

    BASE_LM_PATH = "distilgpt2"
    BASE_TOKENIZER_PATH = "distilgpt2"

    def __init__(
        self,
        lm: Optional[PreTrainedModel] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        context_len: int = 700,
        prompter: Optional[Prompter] = None,
        lm_path: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
        lm_kwargs: Dict[str, Any] = {},
        tokenizer_kwargs: Dict[str, Any] = {},
    ):
        super().__init__(
            lm,
            tokenizer,
            context_len,
            lm_path,
            tokenizer_path,
            lm_kwargs,
            tokenizer_kwargs,
        )
        # Fall back to a default Prompter built on our own tokenizer.
        self.prompter = Prompter(self.tokenizer) if prompter is None else prompter

    def generate_seed(self, length: int, batch_size: Optional[int] = None):
        """Return a seed of ``length`` repetitions of the 'X' (ground) tile token,
        optionally batched to ``(batch_size, length)``."""
        x_token = self.tokenizer("X", return_tensors="pt").input_ids.squeeze()
        if batch_size is None:
            return x_token.repeat(length)
        return x_token.view(1, 1).repeat(batch_size, length)

    def load_pretrained_lm(self, path: str, lm_kwargs: Dict[str, Any]) -> GPT2Model:
        """Load the GPT LM; ``path == "random"`` initializes untrained weights
        from the base architecture's config instead of downloading a checkpoint.
        Cross-attention is always enabled for prompt conditioning."""
        merged_kwargs = dict(lm_kwargs)
        merged_kwargs["add_cross_attention"] = True
        if path == "random":
            print("Initializing random weights...")
            config = AutoConfig.from_pretrained(self.BASE_LM_PATH, **merged_kwargs)
            return AutoModelWithLMHead.from_config(config)
        return AutoModelWithLMHead.from_pretrained(path, **merged_kwargs)

    def load_pretrained_tokenizer(
        self, path: str, tokenizer_kwargs: Dict[str, Any]
    ) -> GPT2Tokenizer:
        """Load the paired tokenizer; ``"random"`` means the base tokenizer."""
        source = self.BASE_TOKENIZER_PATH if path == "random" else path
        return AutoTokenizer.from_pretrained(source, **tokenizer_kwargs)

    def sample(
        self,
        seed: Optional[torch.Tensor] = None,
        prompts: Optional[List[str]] = None,
        num_steps: int = 1,
        temperature: float = 2.0,
        encoder_hidden_states: torch.Tensor = None,
        use_tqdm: bool = False,
        return_tensor: bool = False,
    ) -> SampleOutput:
        """Sample ``num_steps`` new tokens with a fresh :class:`GPTSampler`
        (top-k = 16) and return the decoded :class:`SampleOutput`."""
        sampler = GPTSampler(self, temperature, 16, self.context_len, use_tqdm)
        return sampler(
            seed=seed,
            prompts=prompts,
            num_steps=num_steps,
            encoder_hidden_states=encoder_hidden_states,
            return_tensor=return_tensor,
        )
|
mario_gpt/prompter.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import random
|
4 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from scipy import stats
|
9 |
+
from transformers import pipeline
|
10 |
+
|
11 |
+
from mario_gpt.dataset import MarioDataset
|
12 |
+
from mario_gpt.utils import view_level
|
13 |
+
|
14 |
+
STATISTICS = {
|
15 |
+
"enemy": np.array([1.0, 3.0, 7.0]),
|
16 |
+
"pipe": np.array([0.0, 2.0, 5.0]),
|
17 |
+
"block": np.array([50.0, 75.0, 176.0]),
|
18 |
+
}
|
19 |
+
|
20 |
+
FEATURE_EXTRACTION_MODEL = "facebook/bart-base"
|
21 |
+
|
22 |
+
|
23 |
+
class Prompter:
    """Builds natural-language prompts from Mario levels and embeds them.

    Counts pipes / enemies / blocks in a level's tile string, buckets each
    count into a keyword ("no" / "little" / "some" / "many") via thresholds
    derived from the training dataset, and encodes the resulting phrase
    with a frozen feature-extraction pipeline to condition the generator.
    """

    def __init__(
        self,
        level_tokenizer,
        prompter_model: str = FEATURE_EXTRACTION_MODEL,
        use_raw_counts: bool = False,
        statistics: Optional[Dict[str, Any]] = None,
    ):
        # Frozen text encoder used to embed prompt strings (BART base by default).
        self.prompter_model = prompter_model
        self.feature_extraction = pipeline(
            "feature-extraction",
            model=prompter_model,
            tokenizer=prompter_model,
            framework="pt",
        )

        # Tokenizer that decodes level tensors back into tile strings.
        self.level_tokenizer = level_tokenizer

        # When True, prompts carry exact counts instead of bucket keywords.
        self.use_raw_counts = use_raw_counts
        self.statistics = statistics
        if statistics is None:
            self.statistics = STATISTICS

    @property
    def pipe_thresholds(self) -> Tuple[List[int], List[str]]:
        """Bucket edges and keywords for the pipe count."""
        thresholds = self.statistics["pipe"]
        keywords = ["no", "little", "some", "many"]
        return thresholds, keywords

    @property
    def enemy_thresholds(self) -> Tuple[List[int], List[str]]:
        """Bucket edges and keywords for the enemy count."""
        thresholds = self.statistics["enemy"]
        keywords = ["no", "little", "some", "many"]
        return thresholds, keywords

    @property
    def block_thresholds(self) -> Tuple[List[int], List[str]]:
        """Bucket edges and keywords for the block count.

        "little" is repeated because levels always contain some blocks,
        so there is no meaningful "no blocks" bucket.
        """
        thresholds = self.statistics["block"]
        keywords = ["little", "little", "some", "many"]
        return thresholds, keywords

    def count_pipes(self, flattened_level: str) -> int:
        """Count pipe openings ("<>" tile pairs) in the flattened level."""
        return flattened_level.count("<>")

    def count_enemies(self, flattened_level: str) -> int:
        """Count enemy tiles ('E' and 'B') in the flattened level."""
        return flattened_level.count("E") + flattened_level.count("B")

    def count_blocks(self, flattened_level: str) -> int:
        # Returns a numpy integer (np.sum), which behaves like int for callers.
        """Count solid/breakable/question block tiles (X, S, ?, Q)."""
        return np.sum([flattened_level.count(char) for char in ["X", "S", "?", "Q"]])

    def _flatten_level(self, string_level: List[str]) -> str:
        """Join level rows into one long string for substring counting."""
        return "".join(string_level)

    def pipe_prompt(self, flattened_level: str, level: str) -> Tuple[str, str]:
        """Return ("<keyword> pipes", keyword) describing the pipe count."""
        count = self.count_pipes(flattened_level)
        keyword = f"{count}"
        if not self.use_raw_counts:
            thresholds, keywords = self.pipe_thresholds
            threshold = np.digitize(count, thresholds, right=True)
            keyword = keywords[threshold]
        return f"{keyword} pipes", keyword

    def enemy_prompt(self, flattened_level: str, level: str) -> Tuple[str, str]:
        """Return ("<keyword> enemies", keyword) describing the enemy count."""
        count = self.count_enemies(flattened_level)
        keyword = f"{count}"
        if not self.use_raw_counts:
            thresholds, keywords = self.enemy_thresholds
            threshold = np.digitize(count, thresholds, right=True)
            keyword = keywords[threshold]
        return f"{keyword} enemies", keyword

    def block_prompt(self, flattened_level: str, level: str) -> Tuple[str, str]:
        """Return ("<keyword> blocks", keyword) describing the block count."""
        count = self.count_blocks(flattened_level)
        keyword = f"{count}"
        if not self.use_raw_counts:
            thresholds, keywords = self.block_thresholds
            threshold = np.digitize(count, thresholds, right=True)
            keyword = keywords[threshold]
        return f"{keyword} blocks", keyword

    def elevation_prompt(self, flattened_level: str, level: str) -> Tuple[str, str]:
        """Return ("high elevation"/"low elevation", keyword).

        A level is "high" if any solid or pipe tile appears in its top rows.
        """
        top_levels = level[:6]  # elevation 8 and up
        for t in top_levels:
            if "X" in t or "<" in t or ">" in t:
                return "high elevation", "high"
        return "low elevation", "low"

    def output_hidden(self, prompt: str, device: torch.device = torch.device("cpu")):
        """Embed ``prompt`` into a (1, hidden_dim) tensor on ``device``."""
        # Reducing along the first dimension to get a 768 dimensional array
        return (
            self.feature_extraction(prompt, return_tensors="pt")[0]
            .mean(0)
            .to(device)
            .view(1, -1)
        )

    def dataset_statistics(self, dataset: MarioDataset):
        """Compute per-feature bucket thresholds from a dataset.

        Returns a dict mapping "enemy"/"pipe"/"block" to the 33rd, 66th and
        95th percentile of the respective counts across all windows.
        """
        enemy_counts = []
        pipe_counts = []
        block_counts = []
        for i in range(len(dataset)):
            level, _ = dataset[i]
            str_level = self._flatten_level(view_level(level, dataset.tokenizer))

            enemy_count = self.count_enemies(str_level)
            pipe_count = self.count_pipes(str_level)
            block_count = self.count_blocks(str_level)

            enemy_counts.append(enemy_count)
            pipe_counts.append(pipe_count)
            block_counts.append(block_count)
        d = {"enemy": {}, "pipe": {}, "block": {}}

        d["enemy"] = stats.mstats.mquantiles(enemy_counts, [0.33, 0.66, 0.95])
        d["pipe"] = stats.mstats.mquantiles(pipe_counts, [0.33, 0.66, 0.95])
        d["block"] = stats.mstats.mquantiles(block_counts, [0.33, 0.66, 0.95])
        return d

    def __call__(
        self, level: torch.Tensor = None, sample_prompt: bool = False
    ) -> Tuple[str, torch.Tensor, Dict[str, str], Optional[List[str]]]:
        """Produce (prompt string, embedding, per-feature dict, level rows).

        With ``sample_prompt=True`` the keywords are drawn uniformly at
        random and no level is needed; otherwise they are derived from
        ``level``. The returned level rows are None in the sampled case.
        """
        device: torch.device = torch.device("cpu")
        if not sample_prompt:
            if level is None:
                raise ValueError("Level must be provided if sample_prompt is not true!")
            str_level = view_level(level, self.level_tokenizer)
            flattened_level = self._flatten_level(str_level)

            pipe_prompt, _ = self.pipe_prompt(flattened_level, str_level)
            enemy_prompt, _ = self.enemy_prompt(flattened_level, str_level)
            block_prompt, _ = self.block_prompt(flattened_level, str_level)
            elevation_prompt, _ = self.elevation_prompt(flattened_level, str_level)
            device = level.device
        else:
            str_level = None
            pipe_prompt = random.choice(["no", "little", "some", "many"]) + " pipes"
            enemy_prompt = random.choice(["no", "little", "some", "many"]) + " enemies"
            block_prompt = (
                random.choice(["little", "little", "some", "many"]) + " blocks"
            )  # levels always have blocks
            elevation_prompt = (
                random.choice(["low", "high"]) + " elevation"
            )  # levels always have blocks

        # NOTE(review): the last key is "elevation_prompt" while the others are
        # bare feature names ("pipe", "enemy", "block") — looks inconsistent,
        # but downstream consumers may rely on it; confirm before renaming.
        prompt_dict = {
            "pipe": pipe_prompt,
            "enemy": enemy_prompt,
            "block": block_prompt,
            "elevation_prompt": elevation_prompt,
        }
        prompt = f"{pipe_prompt}, {enemy_prompt}, {block_prompt}, {elevation_prompt}"
        hidden = self.output_hidden(prompt, device=device)
        return prompt, hidden, prompt_dict, str_level
|
mario_gpt/sampler.py
ADDED
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from PIL.Image import Image
|
9 |
+
from tqdm import tqdm
|
10 |
+
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper
|
11 |
+
|
12 |
+
from mario_gpt.lm.base import BaseMarioLM
|
13 |
+
from mario_gpt.prompter import Prompter
|
14 |
+
from mario_gpt.simulator import Simulator
|
15 |
+
from mario_gpt.utils import (
|
16 |
+
convert_level_to_png,
|
17 |
+
load_level,
|
18 |
+
save_level,
|
19 |
+
trim_level,
|
20 |
+
view_level,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
@dataclass
class SampleOutput:
    """Result of a level-sampling run.

    Holds the generated level in parallel representations (tile-string
    rows, PNG render, raw token tensor) alongside the same views of only
    the newly-sampled slice. Any field may be None when its conversion
    failed or was not requested.
    """

    # Full level as rows of tile characters.
    level: Optional[List[str]]
    # Prompt string re-derived from the level by the Prompter (if given).
    prompt: Optional[str] = None
    # PNG render of the full level.
    img: Optional[Image] = None
    # Only the newly generated tokens, as tile rows.
    sample_predictions_str: Optional[List[str]] = None
    # PNG render of only the newly generated tokens.
    sample_predictions_img: Optional[Image] = None
    # Raw token tensors backing the two views above.
    level_tensor: Optional[torch.Tensor] = None
    sample_predictions_tensor: Optional[torch.Tensor] = None

    @classmethod
    def create(
        cls,
        level_tensor: torch.Tensor,
        sample_predictions_tensor: torch.Tensor,
        tokenizer,
        prompter: Optional[Prompter] = None,
    ) -> SampleOutput:
        """Build a SampleOutput for a single (unbatched) level tensor.

        Conversion failures are tolerated: the string/image fields are set
        to None and a warning is printed rather than raising, so a partial
        result is still returned.
        """
        # batch = 1
        level = None
        img = None

        try:
            level = view_level(level_tensor, tokenizer)
            img = convert_level_to_png(level)[0]
        except Exception as e:
            print(
                f"Failed to generate string or image representation for full level! Got error {e}"
            )
            level = None
            img = None
        try:
            sample_predictions_str = view_level(sample_predictions_tensor, tokenizer)
            sample_predictions_img = convert_level_to_png(sample_predictions_str)[0]
        except Exception as e:
            print(
                f"Failed to generate string or image representation for sampled predictions! Got error {e}"
            )
            sample_predictions_str = None
            sample_predictions_img = None

        prompt = None
        if prompter is not None:
            prompt = prompter(level_tensor)[0]

        return SampleOutput(
            level,
            prompt,
            img,
            sample_predictions_str,
            sample_predictions_img,
            level_tensor,
            sample_predictions_tensor,
        )

    @classmethod
    def from_level_predictions(
        cls,
        level: torch.Tensor,
        sample_predictions: torch.Tensor,
        tokenizer,
        prompter: Optional[Prompter] = None,
    ) -> Union[SampleOutput, List[SampleOutput]]:
        """Build SampleOutput(s) from (possibly batched) prediction tensors.

        A 1-D (single-level) tensor yields one SampleOutput; a batched
        tensor yields a list with one entry per level.
        """
        level_tensor = trim_level(level).squeeze().detach().cpu()
        sample_predictions_tensor = (
            trim_level(sample_predictions).squeeze().detach().cpu()
        )

        if len(level_tensor.shape) == 1:
            return SampleOutput.create(
                level_tensor, sample_predictions_tensor, tokenizer, prompter
            )

        out = []
        for _level_tensor, _sample_predictions_tensor in zip(
            level_tensor, sample_predictions_tensor
        ):
            sample_output = SampleOutput.create(
                _level_tensor, _sample_predictions_tensor, tokenizer, prompter
            )
            out.append(sample_output)
        return out

    def save(self, filename: str) -> None:
        """Write the level's tile rows to ``filename``."""
        save_level(self.level, filename)

    @classmethod
    def load(cls, filename: str) -> SampleOutput:
        """Load a level file and wrap it in a SampleOutput (strings only)."""
        level = load_level(filename)
        return SampleOutput(level=level)

    def play(self):
        """Launch the interactive simulator on this level."""
        simulator = Simulator(level=self.level)
        simulator.interactive()

    def run_astar(self, render=True):
        """Run the A* agent through this level, optionally rendering."""
        simulator = Simulator(level=self.level)
        simulator.astar(render)
|
122 |
+
|
123 |
+
|
124 |
+
class GPTSampler:
    """Autoregressive token sampler driving a MarioGPT model.

    Repeatedly feeds the running level back into the model, sampling one
    next token per step with top-k + temperature warping (or argmax), and
    finally decodes the result into a :class:`SampleOutput`.
    """

    def __init__(
        self,
        mario_lm: BaseMarioLM,
        temperature: float = 2.0,
        top_k: int = 16,
        context_len: int = 700,
        use_tqdm: bool = False,
        use_argmax: bool = False,
    ):
        self.mario_lm = mario_lm
        self.temperature = temperature
        self.top_k = top_k
        self.context_len = context_len
        self.use_tqdm = use_tqdm
        # When True, take the most likely token instead of sampling.
        self.use_argmax = use_argmax
        self.logits_processor = LogitsProcessorList()
        self.logits_warper = LogitsProcessorList(
            [
                TopKLogitsWarper(top_k),  # number of characters
                TemperatureLogitsWarper(temperature),
            ]
        )

    @property
    def device(self) -> torch.device:
        """Device of the wrapped language model."""
        return self.mario_lm.device

    def step(
        self,
        seed: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict one next token per batch row given the current sequence.

        Returns ``(next_tokens, encoder_hidden_states)``; the hidden states
        are passed through unchanged so the caller can thread them.
        """
        with torch.no_grad():
            attention_mask = torch.ones_like(seed).to(seed.device)
            input_ids = seed
            out = self.mario_lm.lm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                token_type_ids=None,
            )
            logits = out.logits.detach()
            # Un-batched model output: reshape to (1, 1, vocab).
            if len(logits.shape) == 2:
                logits = logits.view(1, 1, -1)
            next_token_logits = logits[:, -1, :]

            if self.use_argmax:
                next_tokens = next_token_logits.argmax(-1)
            else:
                next_token_scores = self.logits_processor(input_ids, next_token_logits)
                next_token_scores = self.logits_warper(input_ids, next_token_scores)
                probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        return next_tokens, encoder_hidden_states

    def sample(
        self,
        seed: Union[Optional[torch.Tensor], Optional[SampleOutput]] = None,
        prompts: Optional[List[str]] = None,
        num_steps: int = 1,
        encoder_hidden_states: torch.Tensor = None,
        return_tensor: bool = False,
    ):
        """Generate ``num_steps`` tokens and decode them into a SampleOutput.

        The seed may be absent (a default one-token seed per prompt is
        created), a prior SampleOutput (continuation), or a raw tensor.
        Prompt embeddings are computed here unless provided explicitly.
        """
        self.mario_lm.eval()
        # Reserve 28 tokens (two level columns) of slack below the model's
        # hard context limit — presumably to keep the window column-aligned;
        # TODO confirm the exact rationale for the constant.
        context_len = self.context_len - 28
        with torch.no_grad():
            if seed is None:
                seed = self.mario_lm.generate_seed(1, batch_size=len(prompts)).to(
                    self.device
                )
                out_tensor = seed.to(self.device)
            elif isinstance(seed, SampleOutput):
                out_tensor = seed.level_tensor.to(self.device).squeeze()
            else:
                out_tensor = seed.to(self.device).squeeze()
            if len(out_tensor.shape) < 2:
                # if we pass in a single seed vector, then we repeat for each prompt
                # Otherwise, we treat inputs as separate seed-prompt pairs
                out_tensor = out_tensor.view(1, -1).repeat(len(prompts), 1)
            if encoder_hidden_states is None:
                if prompts is not None:
                    encoder_hidden_states = torch.stack(
                        [
                            self.mario_lm.prompter.output_hidden(prompt)
                            for prompt in prompts
                        ]
                    )
                else:
                    # No prompts given: condition each row on a random prompt.
                    encoder_hidden_states = torch.stack(
                        [
                            self.mario_lm.prompter(sample_prompt=True)[1]
                            for _ in range(seed.shape[0])
                        ]
                    )
            encoder_hidden_states = encoder_hidden_states.to(
                self.device
            )  # b x 1 x hidden_dim
            encoder_hidden_states = encoder_hidden_states.view(
                out_tensor.shape[0], 1, -1
            )
            if not self.use_tqdm:
                bar = np.arange(num_steps)
            else:
                bar = tqdm(np.arange(num_steps))
            with torch.no_grad():
                for i in bar:
                    # Multiply by 1 to copy the tensor before slicing.
                    inp = out_tensor * 1
                    if len(out_tensor.shape) > 0 and out_tensor.shape[-1] > context_len:
                        # Trim the left side, keeping the window a multiple of
                        # the level height so columns stay aligned.
                        diff = inp.shape[-1] % 14  # height of mario level
                        ctx = context_len + diff
                        inp = inp[:, -ctx:] * 1
                    next_tokens, encoder_hidden_states = self.step(
                        inp,
                        encoder_hidden_states=encoder_hidden_states,
                    )
                    out_tensor = torch.cat(
                        [out_tensor, next_tokens.unsqueeze(-1)], dim=-1
                    )
                    if self.use_tqdm:
                        bar.set_description(
                            f"shape: {inp.shape}, {out_tensor.shape} first: {inp[0][0]}, last: {out_tensor[0][-1]}"
                        )
            if self.use_tqdm:
                bar.close()
            sample_out = SampleOutput.from_level_predictions(
                out_tensor,
                out_tensor[:, -num_steps:],
                self.mario_lm.tokenizer,
                self.mario_lm.prompter,
            )
            # NOTE(review): unconditionally flips the model back to train mode,
            # even if it was in eval mode before sampling — confirm intended.
            self.mario_lm.train()
            if return_tensor:
                return sample_out, out_tensor
            return sample_out

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`sample`."""
        return self.sample(*args, **kwargs)
|
262 |
+
|
263 |
+
|
264 |
+
class BertSampler:
|
265 |
+
def __init__(
|
266 |
+
self,
|
267 |
+
mario_lm: BaseMarioLM,
|
268 |
+
temperature: float = 2.0,
|
269 |
+
top_k: int = 16,
|
270 |
+
context_len: int = 448,
|
271 |
+
mask_proportion: float = 0.16,
|
272 |
+
):
|
273 |
+
self.mario_lm = mario_lm
|
274 |
+
self.temperature = temperature
|
275 |
+
self.top_k = top_k
|
276 |
+
self.logits_processor = LogitsProcessorList()
|
277 |
+
self.logits_warper = LogitsProcessorList(
|
278 |
+
[
|
279 |
+
TopKLogitsWarper(top_k), # number of characters
|
280 |
+
TemperatureLogitsWarper(temperature),
|
281 |
+
]
|
282 |
+
)
|
283 |
+
self.context_len = context_len
|
284 |
+
self.mask_proportion = mask_proportion
|
285 |
+
self.mask_portion = int(self.context_len * self.mask_proportion)
|
286 |
+
self.mask_portion = self.mask_portion - self.mask_portion % 14 + 14
|
287 |
+
|
288 |
+
@property
|
289 |
+
def device(self) -> torch.device:
|
290 |
+
return self.mario_lm.device
|
291 |
+
|
292 |
+
def get_context(self, input_ids, mask_indices):
    """Clip ``input_ids`` to a window of at most ``self.context_len`` tokens
    around the masked span.

    Args:
        input_ids: 1-D token tensor for a single level.
        mask_indices: sorted 1-D tensor of positions to be inpainted.

    Returns:
        A slice of ``input_ids`` containing the masked span plus an equal
        context budget on each side; when the span sits near the start of
        the sequence, the unused left budget is shifted to the right.
    """
    start_idx = int(mask_indices[0])
    end_idx = int(mask_indices[-1])

    # Sequence already fits in the model context: nothing to clip.
    # BUG FIX: the original truncated here to `len % 14` tokens
    # (`input_ids[:clipped]`), discarding nearly the entire sequence.
    if input_ids.shape[-1] <= self.context_len:
        return input_ids

    # Context budget on each side of the masked span.
    # BUG FIX: originally computed with float division (`/ 2`) and then used
    # directly as slice bounds, which raises TypeError on tensor indexing.
    portion = (self.context_len - self.mask_portion) // 2

    left = start_idx - portion
    # Span near the start: give the unused left budget to the right side.
    remainder = -left if left < 0 else 0
    left = max(left, 0)

    right = end_idx + portion + remainder

    return input_ids[left:right]
|
310 |
+
|
311 |
+
def sample(
|
312 |
+
self,
|
313 |
+
seed: Union[torch.Tensor, SampleOutput],
|
314 |
+
mask: torch.Tensor,
|
315 |
+
return_tensor: bool = False,
|
316 |
+
):
|
317 |
+
self.mario_lm.eval()
|
318 |
+
mask_indices = mask.nonzero()
|
319 |
+
input_ids = seed
|
320 |
+
if isinstance(seed, SampleOutput):
|
321 |
+
input_ids = seed.level_tensor.to(self.device).squeeze()
|
322 |
+
|
323 |
+
input_id_list = []
|
324 |
+
for i in range(input_ids.shape[0]):
|
325 |
+
input_id = input_ids[i]
|
326 |
+
mask_index = mask_indices[mask_indices[:, 0] == i][:, -1]
|
327 |
+
input_id = self.get_context(input_id, mask_index)
|
328 |
+
input_id_list.append(input_id)
|
329 |
+
input_ids = torch.stack(input_ids, dim=0).to(self.device)
|
330 |
+
|
331 |
+
attention_mask = torch.ones_like(input_ids).to(seed.device)
|
332 |
+
|
333 |
+
if len(input_ids.shape) < 2:
|
334 |
+
# if we pass in a single seed vector, then we repeat for each prompt
|
335 |
+
# Otherwise, we treat inputs as separate seed-prompt pairs
|
336 |
+
input_ids = input_ids.view(1, -1)
|
337 |
+
|
338 |
+
out = self.mario_lm.lm(
|
339 |
+
input_ids=input_ids,
|
340 |
+
attention_mask=attention_mask,
|
341 |
+
token_type_ids=None,
|
342 |
+
)
|
343 |
+
logits = out.logits.detach()
|
344 |
+
if len(logits.shape) == 2:
|
345 |
+
logits = logits.view(1, 1, -1)
|
346 |
+
|
347 |
+
if self.use_argmax:
|
348 |
+
tokens = logits.argmax(-1)
|
349 |
+
else:
|
350 |
+
tokens_scores = self.logits_processor(input_ids, tokens)
|
351 |
+
tokens_scores = self.logits_warper(input_ids, tokens_scores)
|
352 |
+
probs = torch.nn.functional.softmax(tokens_scores, dim=-1)
|
353 |
+
tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
354 |
+
|
355 |
+
out = input_ids.detach()
|
356 |
+
|
357 |
+
for i in range(input_ids.shape[0]):
|
358 |
+
mask_index = mask_indices[mask_indices[:, 0] == i][:, -1]
|
359 |
+
out[i, mask_index] = tokens[i, mask_index].detach()
|
360 |
+
|
361 |
+
sample_out = SampleOutput.from_level_predictions(
|
362 |
+
out,
|
363 |
+
tokens,
|
364 |
+
self.mario_lm.tokenizer,
|
365 |
+
self.mario_lm.prompter,
|
366 |
+
)
|
367 |
+
self.mario_lm.train()
|
368 |
+
if return_tensor:
|
369 |
+
return sample_out, tokens
|
370 |
+
return sample_out
|
mario_gpt/simulator/PlayAstar.jar
ADDED
Binary file (78.1 kB). View file
|
|
mario_gpt/simulator/PlayLevel.jar
ADDED
Binary file (78 kB). View file
|
|
mario_gpt/simulator/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from mario_gpt.simulator.simulator import Simulator
|
2 |
+
|
3 |
+
__all__ = ["Simulator"]
|