Banjo Obayomi committed on
Commit
0a32c0e
1 Parent(s): ac9d471

init commit

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitignore +54 -0
  2. LICENSE +21 -0
  3. app.py +446 -0
  4. mario_gpt/__init__.py +17 -0
  5. mario_gpt/data/tiles/N.png +0 -0
  6. mario_gpt/data/tiles/Y.png +0 -0
  7. mario_gpt/data/tiles/cannon_bottom.png +0 -0
  8. mario_gpt/data/tiles/cannon_top.png +0 -0
  9. mario_gpt/data/tiles/flying_koopa.png +0 -0
  10. mario_gpt/data/tiles/ki-background.png +0 -0
  11. mario_gpt/data/tiles/ki-door.png +0 -0
  12. mario_gpt/data/tiles/ki-hazard.png +0 -0
  13. mario_gpt/data/tiles/ki-moving-platform.png +0 -0
  14. mario_gpt/data/tiles/ki-passable.png +0 -0
  15. mario_gpt/data/tiles/ki-path.png +0 -0
  16. mario_gpt/data/tiles/ki-unpassable.png +0 -0
  17. mario_gpt/data/tiles/mm-CMM.png +0 -0
  18. mario_gpt/data/tiles/mm-DMM.png +0 -0
  19. mario_gpt/data/tiles/mm-HMM.png +0 -0
  20. mario_gpt/data/tiles/mm-LMM.png +0 -0
  21. mario_gpt/data/tiles/mm-MMM.png +0 -0
  22. mario_gpt/data/tiles/mm-TMM.png +0 -0
  23. mario_gpt/data/tiles/mma_tiles.zip +3 -0
  24. mario_gpt/data/tiles/plant.png +0 -0
  25. mario_gpt/data/tiles/smb-background.png +0 -0
  26. mario_gpt/data/tiles/smb-breakable.png +0 -0
  27. mario_gpt/data/tiles/smb-coin.png +0 -0
  28. mario_gpt/data/tiles/smb-enemy.png +0 -0
  29. mario_gpt/data/tiles/smb-path.png +0 -0
  30. mario_gpt/data/tiles/smb-question.png +0 -0
  31. mario_gpt/data/tiles/smb-tube-lower-left.png +0 -0
  32. mario_gpt/data/tiles/smb-tube-lower-right.png +0 -0
  33. mario_gpt/data/tiles/smb-tube-top-left.png +0 -0
  34. mario_gpt/data/tiles/smb-tube-top-right.png +0 -0
  35. mario_gpt/data/tiles/smb-unpassable.png +0 -0
  36. mario_gpt/data/tiles/smb_enemies_sheet.png +0 -0
  37. mario_gpt/data/tiles/tile004 (1).png +0 -0
  38. mario_gpt/data/tiles/tile004 (2).png +0 -0
  39. mario_gpt/data/tiles/tile004.png +0 -0
  40. mario_gpt/dataset.py +138 -0
  41. mario_gpt/level.py +0 -0
  42. mario_gpt/lm/__init__.py +44 -0
  43. mario_gpt/lm/base.py +91 -0
  44. mario_gpt/lm/bert.py +95 -0
  45. mario_gpt/lm/gpt.py +97 -0
  46. mario_gpt/prompter.py +175 -0
  47. mario_gpt/sampler.py +370 -0
  48. mario_gpt/simulator/PlayAstar.jar +0 -0
  49. mario_gpt/simulator/PlayLevel.jar +0 -0
  50. mario_gpt/simulator/__init__.py +3 -0
.gitignore ADDED
@@ -0,0 +1,54 @@
+ *.py[cod]
+
+ # C extensions
+ *.so
+
+ # Packages
+ *.egg
+ *.egg-info
+ dist
+ build
+ eggs
+ parts
+ bin
+ var
+ sdist
+ develop-eggs
+ .installed.cfg
+ lib
+ lib64
+ __pycache__
+
+ # Installer logs
+ pip-log.txt
+
+ # Unit test / coverage reports
+ .coverage
+ .tox
+ nosetests.xml
+
+ # Translations
+ *.mo
+
+ # Mr Developer
+ .mr.developer.cfg
+ .project
+ .pydevproject
+ test.json
+ *.pickle
+ venv
+ .idea
+ *.vscode/
+ .DS_Store
+
+ #notebooks
+ */**/.ipynb_checkpoints/*
+
+ # logs
+ */**/checkpoints/*
+ */**/mlruns/*
+ */**/tensorboard_logs/*
+ */**/wandb/*
+ checkpoints/*
+ wandb/*
+ mlruns/*
LICENSE ADDED
@@ -0,0 +1,21 @@
+ The MIT License (MIT)
+
+ Copyright (c) 2023 Shyam Sudhakaran
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,446 @@
+ import gradio as gr
+ import uuid
+ from mario_gpt.lm import MarioLM
+ from mario_gpt.utils import convert_level_to_png
+
+ from fastapi import FastAPI
+ from fastapi.staticfiles import StaticFiles
+
+ import uvicorn
+ import boto3
+ import json
+
+ bedrock_runtime = boto3.client(
+     service_name="bedrock-runtime",
+     region_name="us-east-1",
+ )
+
+
+ def get_raw_text(level_data):
+     raw_text = ""
+     for line in level_data:
+         raw_text += line + "\n"
+     return raw_text
+
+
+ def combine_levels(level_arrays):
+     num_rows = len(level_arrays[0])
+
+     combined_level = []
+
+     for row in range(num_rows):
+         combined_row = ""
+         for level in level_arrays:
+             combined_row += level[row]
+         combined_level.append(combined_row)
+
+     return combined_level
+
+
+ def write_level_to_file(level_data, file_name):
+     with open(file_name, "w") as file:
+         for line in level_data:
+             file.write(line + "\n")
+
+
+ def clean_level_data(input_string):
+     # Find the start and end indices of the level data
+     start_index = input_string.find("[")
+     end_index = input_string.rfind("]")
+
+     # Extract the level data
+     level_data = input_string[start_index + 1 : end_index]
+
+     # Split the level data into lines
+     lines = level_data.split(",")
+
+     # Clean each line
+     cleaned_lines = []
+     for line in lines:
+         # Remove leading and trailing whitespace and quotes
+         cleaned_line = line.strip().strip("'")
+
+         # Ensure the line has exactly 50 characters
+         if len(cleaned_line) < 50:
+             cleaned_line += "-" * (50 - len(cleaned_line))
+         elif len(cleaned_line) > 50:
+             cleaned_line = cleaned_line[:50]
+
+         cleaned_lines.append(cleaned_line)
+
+     return cleaned_lines
+
+
+ def call_llama3_70b(system_prompt, prompt):
+     prompt_config = {
+         "prompt": system_prompt + prompt,
+         "max_gen_len": 2048,
+         "top_p": 0.9,
+         "temperature": 0.7,
+     }
+
+     body = json.dumps(prompt_config)
+
+     modelId = "meta.llama3-70b-instruct-v1:0"
+     accept = "application/json"
+     contentType = "application/json"
+
+     response = bedrock_runtime.invoke_model(
+         body=body, modelId=modelId, accept=accept, contentType=contentType
+     )
+     response_body = json.loads(response.get("body").read())
+
+     results = response_body["generation"].strip()
+     return results
+
+
+ def call_llama3_8b(system_prompt, prompt):
+     prompt_config = {
+         "prompt": system_prompt + prompt,
+         "max_gen_len": 2048,
+         "top_p": 0.9,
+         "temperature": 0.7,
+     }
+
+     body = json.dumps(prompt_config)
+
+     modelId = "meta.llama3-8b-instruct-v1:0"
+     accept = "application/json"
+     contentType = "application/json"
+
+     response = bedrock_runtime.invoke_model(
+         body=body, modelId=modelId, accept=accept, contentType=contentType
+     )
+     response_body = json.loads(response.get("body").read())
+
+     results = response_body["generation"].strip()
+     return results
+
+
+ # def call_claude_3_opus(system_prompt, prompt):
+
+ #     prompt_config = {
+ #         "anthropic_version": "bedrock-2023-05-31",
+ #         "max_tokens": 4096,
+ #         "system": system_prompt,
+ #         "messages": [
+ #             {
+ #                 "role": "user",
+ #                 "content": [
+ #                     {"type": "text", "text": prompt},
+ #                 ],
+ #             }
+ #         ],
+ #     }
+
+ #     body = json.dumps(prompt_config)
+
+ #     modelId = "anthropic.claude-3-opus-20240229-v1:0"
+ #     accept = "application/json"
+ #     contentType = "application/json"
+
+ #     response = bedrock_runtime.invoke_model(
+ #         body=body, modelId=modelId, accept=accept, contentType=contentType
+ #     )
+ #     response_body = json.loads(response.get("body").read())
+
+ #     results = response_body.get("content")[0].get("text")
+ #     return results
+
+
+ # Call Claude model
+ def call_claude_3_sonnet(system_prompt, prompt):
+
+     prompt_config = {
+         "anthropic_version": "bedrock-2023-05-31",
+         "max_tokens": 4096,
+         "system": system_prompt,
+         "messages": [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "text", "text": prompt},
+                 ],
+             }
+         ],
+     }
+
+     body = json.dumps(prompt_config)
+
+     modelId = "anthropic.claude-3-sonnet-20240229-v1:0"
+     accept = "application/json"
+     contentType = "application/json"
+
+     response = bedrock_runtime.invoke_model(
+         body=body, modelId=modelId, accept=accept, contentType=contentType
+     )
+     response_body = json.loads(response.get("body").read())
+
+     results = response_body.get("content")[0].get("text")
+     return results
+
+
+ def call_claude_3_haiku(system_prompt, prompt):
+
+     prompt_config = {
+         "anthropic_version": "bedrock-2023-05-31",
+         "max_tokens": 4096,
+         "system": system_prompt,
+         "messages": [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "text", "text": prompt},
+                 ],
+             }
+         ],
+     }
+
+     body = json.dumps(prompt_config)
+
+     modelId = "anthropic.claude-3-haiku-20240307-v1:0"
+     accept = "application/json"
+     contentType = "application/json"
+
+     response = bedrock_runtime.invoke_model(
+         body=body, modelId=modelId, accept=accept, contentType=contentType
+     )
+     response_body = json.loads(response.get("body").read())
+
+     results = response_body.get("content")[0].get("text")
+     return results
+
+
+ system_prompt_text = """
+ As an esteemed level designer renowned for creating some of the top 100 levels in Super Mario Maker, you are tasked with crafting a playable section for the original Super Mario on NES. Your extensive experience and creativity are key to designing levels that are not only challenging but also immensely enjoyable. Use the following symbols to represent different game elements, ensuring each level is a masterpiece of design:
+
+ <symbols>
+ - = "Sky"
+ X = "Unbreakable Block"
+ E = "Enemy"
+ o = "Coin"
+ S = "Breakable Block"
+ ? = "Question Block"
+ [] = "Pipe"
+ <> = "End of Pipe"
+ </symbols>
+
+
+ Adhere to these level layout specifications:
+
+ <level guidelines>
+ Pipes should be vertical and follow this format:
+ <>
+ []
+ []
+
+ Ensure there is a clear and navigable path that Mario can follow from the start to the end of the level. This path may involve jumping on blocks or pipes, or running on blocks.
+
+ The path should be continuous and not lead Mario into any dead ends or impossible situations.
+
+ Place unbreakable blocks (X) or other platform elements strategically to create a solid foundation for Mario to walk on. Avoid creating large gaps or sections without any ground or platforms, as Mario needs a surface to stand on.
+
+ Adjust the complexity and elements based on the specific level request, ensuring that Mario can always complete the level successfully by following the designated path.
+ </level guidelines>
+
+
+
+ For example, consider a prompt that asks for a level with one pipe and blocks with 2 Goombas.
+
+ Here is an example output:
+ <example>
+ ['--------------------------------------------------',
+ '--------------------------------------------------',
+ '--------------------------------------------------',
+ '--------------------------------------------------',
+ '-------------------------------------------------o',
+ '--------XSSSSS---------------------------------SSS',
+ '--------X-----------------------------------------',
+ '--------X-----------------------------------------',
+ '-------EX--E-X---------------xxxx-?-----------xxxx',
+ '--------XSS?SX---QQ?QQ------xx<>-x-----------xx--?',
+ '---------------------------xx-[]--x---------xx----',
+ '--------------------------xx--[]---x-------xx-----',
+ 'xxxxxxxxxxxxxxxxxxxxxxxxxxx---[]----xxxxxxxx------',
+ 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX---XXX']
+ </example>
+
+
+ Generate the level section as a 2D array, where each row is represented as a string of characters. The level section should be 14 rows tall and 50 columns wide. Only return the 2D array of characters.
+
+ Remember, your creations should challenge players but remain fair. Use your expertise to weave together obstacles and rewards, encouraging exploration and skillful play. Always ensure that Mario has a clear and navigable route to finish the level, and provide ample block tiles for Mario to walk on.
+ """
+
+
+ mario_lm = MarioLM()
+ # device = torch.device('cuda')
+ # mario_lm = mario_lm.to(device)
+ TILE_DIR = "mario_gpt/data/tiles"
+
+ app = FastAPI()
+
+
+ def make_html_file(generated_level):
+     level_text = generated_level
+     unique_id = uuid.uuid1()
+     with open(f"static/demo-{unique_id}.html", "w", encoding="utf-8") as f:
+         f.write(
+             f"""<!DOCTYPE html>
+ <html lang="en">
+ <head>
+ <meta charset="utf-8">
+ <title>Mario Game</title>
+ <script src="https://cjrtnc.leaningtech.com/20230216/loader.js"></script>
+ </head>
+ <body>
+ </body>
+ <script>
+ cheerpjInit().then(function () {{
+ cheerpjAddStringFile("/str/mylevel.txt", `{level_text}`);
+ }});
+ cheerpjCreateDisplay(612, 600);
+ cheerpjRunJar("/app/static/mario.jar");
+ </script>
+ </html>"""
+         )
+     return f"demo-{unique_id}.html"
+
+
+ def generate(model, prompt, system_prompt=system_prompt_text):
+
+     print(f"Using prompt: {prompt}")
+
+     if system_prompt == "":
+         system_prompt = system_prompt_text
+
+     # # prompt 3 times
+     # prompts = [prompt, prompt, prompt]
+
+     # levels_array = []
+
+     # for index, prompt in enumerate(prompts):
+
+     #     level = call_claude_3_sonnet(system_prompt, prompt)
+     #     cleaned_level = clean_level_data(level)
+
+     #     levels_array.append(cleaned_level)
+
+     # final_level = combine_levels(levels_array)
+     # raw_level_text = get_raw_text(final_level)
+
+     if model == "Claude Sonnet":
+         level = call_claude_3_sonnet(system_prompt, prompt)
+     elif model == "Claude Haiku":
+         level = call_claude_3_haiku(system_prompt, prompt)
+     elif model == "Llama3 70B":
+         level = call_llama3_70b(system_prompt, prompt)
+     elif model == "Llama3 8B":
+         level = call_llama3_8b(system_prompt, prompt)
+     # elif model == "Claude Opus":
+     #     level = call_claude_3_opus(system_prompt, prompt)
+     else:
+         raise ValueError("Invalid model")
+
+     # level = call_claude_3_sonnet(system_prompt, prompt)
+     cleaned_level = clean_level_data(level)
+     raw_level_text = get_raw_text(cleaned_level)
+
+     filename = make_html_file(raw_level_text)
+     img = convert_level_to_png(cleaned_level, mario_lm.tokenizer)[0]
+
+     gradio_html = f"""<div>
+         <iframe width=612 height=612 style="margin: 0 auto" src="static/{filename}"></iframe>
+         <p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fire flowers.</p>
+     </div>"""
+     return [img, gradio_html]
+
+
+ with gr.Blocks().queue() as demo:
+     gr.Markdown(
+         """### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models - Amazon Bedrock Edition
+         [[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)], [[Amazon Bedrock](https://docs.aws.amazon.com/bedrock/latest/userguide/service_code_examples.html?trk=2403b700-9ee9-49e8-aed8-411dea5cf5ae&sc_channel=el)]
+         """
+     )
+     with gr.Tabs():
+         with gr.TabItem("Prompt Settings"):
+
+             with gr.Accordion(label="System Prompt", open=False):
+                 # temperature = gr.Number(
+                 #     value=2.0,
+                 #     label="temperature: Increase these for more diverse, but lower quality, generations",
+                 # )
+                 system_prompt = gr.TextArea(
+                     value=system_prompt_text,
+                     label="Enter your MarioGPT System prompt. ex: 'As an esteemed level designer renowned for creating some of the top 100 levels in Super Mario Maker...'",
+                 )
+
+             text_prompt = gr.Textbox(
+                 value="Generate a level with a few pipes, many coins. make sure there are only 10 enemies. Make sure there is a ground path Mario can walk on",
+                 label="Enter your MarioGPT prompt. ex: 'Generate a level with a few pipes, many coins. make sure there are only 10 enemies. Make sure there is a ground path Mario can walk on'",
+             )
+
+             model = gr.Radio(
+                 [
+                     # "Claude Opus",  # no Opus for demo
+                     "Claude Sonnet",
+                     "Claude Haiku",
+                     "Llama3 70B",
+                     "Llama3 8B",
+                 ],
+                 label="Select Model",
+                 value="Claude Sonnet",
+             )
+
+             # with gr.Accordion(label="Advanced settings", open=False):
+             #     temperature = gr.Number(
+             #         value=0.7,
+             #         label="temperature: Increase for more randomness",
+             #     )
+             #     level_size = gr.Slider(
+             #         value=1,
+             #         minimum=1,
+             #         maximum=5,
+             #         step=1,
+             #         label="level_size",
+             #     )
+
+             btn = gr.Button("Generate level")
+             with gr.Row():
+                 with gr.Group():
+                     level_play = gr.HTML()
+                     level_image = gr.Image()
+             btn.click(
+                 fn=generate,
+                 inputs=[
+                     # temperature,
+                     # level_size,
+                     model,
+                     text_prompt,
+                     system_prompt,
+                 ],
+                 outputs=[level_image, level_play],
+             )
+             gr.Examples(
+                 examples=[
+                     [
+                         "Claude Sonnet",
+                         "Generate a level with a few pipes, many coins. make sure there are only 10 enemies. Make sure there is a ground path Mario can walk on",
+                     ],
+                     [
+                         "Claude Sonnet",
+                         "Design a level with blocks arranged in a pyramid-like shape, with coins scattered around the base and goombas guarding the top. Have a pipe at the top.",
+                     ],
+                     [
+                         "Claude Sonnet",
+                         "Make a simple level that has no enemies, but lots and lots of coins. Lots of blocks for mario to walk on.",
+                     ],
+                 ],
+                 inputs=[model, text_prompt],
+                 outputs=[level_image, level_play],
+                 fn=generate,
+                 cache_examples=True,
+             )
+
+ app.mount("/static", StaticFiles(directory="static", html=True), name="static")
+ app = gr.mount_gradio_app(app, demo, "/")
+ uvicorn.run(app, host="0.0.0.0", port=7860)
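
Taken together, `generate` is a small text-to-level pipeline: ask a Bedrock model for a 2D character array, normalize it to rows of exactly 50 characters, then render it. A minimal sketch of that flow without the Gradio/FastAPI wiring, assuming AWS credentials with Bedrock model access are configured:

    # hypothetical standalone run of the helpers defined in app.py
    raw = call_claude_3_sonnet(
        system_prompt_text,
        "Generate a level with a few pipes and many coins.",
    )
    rows = clean_level_data(raw)      # pad/trim each row to 50 characters
    level_text = get_raw_text(rows)   # newline-joined text fed to the CheerpJ player
    img = convert_level_to_png(rows, mario_lm.tokenizer)[0]  # PIL preview image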
mario_gpt/__init__.py ADDED
@@ -0,0 +1,17 @@
+ from mario_gpt.dataset import MarioDataset
+ from mario_gpt.lm import MarioBert, MarioGPT, MarioLM
+ from mario_gpt.prompter import Prompter
+ from mario_gpt.sampler import GPTSampler, SampleOutput
+ from mario_gpt.trainer import MarioGPTTrainer, TrainingConfig
+
+ __all__ = [
+     "Prompter",
+     "MarioDataset",
+     "MarioBert",
+     "MarioGPT",
+     "MarioLM",
+     "SampleOutput",
+     "GPTSampler",
+     "TrainingConfig",
+     "MarioGPTTrainer",
+ ]
mario_gpt/data/tiles/N.png ADDED
mario_gpt/data/tiles/Y.png ADDED
mario_gpt/data/tiles/cannon_bottom.png ADDED
mario_gpt/data/tiles/cannon_top.png ADDED
mario_gpt/data/tiles/flying_koopa.png ADDED
mario_gpt/data/tiles/ki-background.png ADDED
mario_gpt/data/tiles/ki-door.png ADDED
mario_gpt/data/tiles/ki-hazard.png ADDED
mario_gpt/data/tiles/ki-moving-platform.png ADDED
mario_gpt/data/tiles/ki-passable.png ADDED
mario_gpt/data/tiles/ki-path.png ADDED
mario_gpt/data/tiles/ki-unpassable.png ADDED
mario_gpt/data/tiles/mm-CMM.png ADDED
mario_gpt/data/tiles/mm-DMM.png ADDED
mario_gpt/data/tiles/mm-HMM.png ADDED
mario_gpt/data/tiles/mm-LMM.png ADDED
mario_gpt/data/tiles/mm-MMM.png ADDED
mario_gpt/data/tiles/mm-TMM.png ADDED
mario_gpt/data/tiles/mma_tiles.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f6d58bb3228bcd3c653c4a58b69044588ffd6e5e4c946a860497a39d84eb60b8
+ size 6586
mario_gpt/data/tiles/plant.png ADDED
mario_gpt/data/tiles/smb-background.png ADDED
mario_gpt/data/tiles/smb-breakable.png ADDED
mario_gpt/data/tiles/smb-coin.png ADDED
mario_gpt/data/tiles/smb-enemy.png ADDED
mario_gpt/data/tiles/smb-path.png ADDED
mario_gpt/data/tiles/smb-question.png ADDED
mario_gpt/data/tiles/smb-tube-lower-left.png ADDED
mario_gpt/data/tiles/smb-tube-lower-right.png ADDED
mario_gpt/data/tiles/smb-tube-top-left.png ADDED
mario_gpt/data/tiles/smb-tube-top-right.png ADDED
mario_gpt/data/tiles/smb-unpassable.png ADDED
mario_gpt/data/tiles/smb_enemies_sheet.png ADDED
mario_gpt/data/tiles/tile004 (1).png ADDED
mario_gpt/data/tiles/tile004 (2).png ADDED
mario_gpt/data/tiles/tile004.png ADDED
mario_gpt/dataset.py ADDED
@@ -0,0 +1,138 @@
+ from __future__ import annotations
+
+ from typing import List, Optional
+
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset
+ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
+
+ from mario_gpt.level import FULL_LEVEL_STR_WITH_PATHS
+
+ DEFAULT_MODEL = "distilgpt2"
+
+
+ def split_given_size(a, size):
+     return np.split(a, np.arange(size, len(a), size))
+
+
+ def flip_and_transpose(arr: np.array, flip_first: bool = False):
+     if arr.shape[-1] > 1:
+         if flip_first:
+             return np.flip(arr, -1).transpose()
+         return np.flip(arr.transpose(), -1)
+     return arr
+
+
+ def join_list_of_list(str_lists):
+     return ["".join(s) for s in str_lists]
+
+
+ def characterize(str_lists):
+     return [list(s) for s in str_lists]
+
+
+ class MarioDataset(Dataset):
+     def __init__(
+         self,
+         tokenizer: Optional[PreTrainedTokenizer] = None,
+         level_string: Optional[str] = None,
+         context_len: int = 700,
+         height: int = 14,
+         remove_start_end_tokens: bool = False,
+         sample_all_indices: bool = False,
+     ):
+         if level_string is None:
+             print(
+                 "No level string specified, using default string FULL_LEVEL_STR_WITH_PATHS..."
+             )
+             level_string = FULL_LEVEL_STR_WITH_PATHS
+         elif ".txt" in level_string:
+             with open(level_string, "r") as file:
+                 level_string = file.read()
+
+         self.character_set = set(level_string)
+         if "\n" in self.character_set:
+             self.character_set.remove("\n")
+         self.vocab_size = len(self.character_set)
+         self.sample_all_indices = sample_all_indices
+
+         def get_training_corpus():
+             yield list(level_string)
+
+         if tokenizer is None:
+             tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
+
+         self.tokenizer = tokenizer
+         if getattr(tokenizer, "train_new_from_iterator", None) is not None:
+             self.tokenizer = self.tokenizer.train_new_from_iterator(
+                 get_training_corpus(), 52000
+             )
+         elif getattr(tokenizer, "train_from_iterator", None) is not None:
+             self.tokenizer = PreTrainedTokenizerFast(tokenizer_object=self.tokenizer)
+             self.tokenizer = self.tokenizer.train_new_from_iterator(
+                 get_training_corpus(), self.vocab_size
+             )
+         self.context_len = context_len
+         self.height = height
+
+         # keep the raw tokenizer output around; __str__ decodes from it
+         self.x, self.str_arr = self.convert_level_to_tensor(level_string.split("\n"))
+         self.input_ids = self.x["input_ids"].squeeze()
+         self.attention_masks = self.x["attention_mask"].squeeze()
+         if remove_start_end_tokens:
+             self.input_ids = self.input_ids[1:-1]
+             self.attention_masks = self.attention_masks[1:-1]
+
+         self.indices = self.generate_indices()
+
+         self.unique_tokens, self.unique_counts = self.input_ids.unique(
+             return_counts=True
+         )
+         self.weighted_unique_counts = (
+             1.0 / self.unique_counts / torch.sum(self.unique_counts)
+         )
+
+         self.token_dict = {}
+         string_tokens = list(self.tokenizer.decode(self.unique_tokens))
+         for int_token, string_token in zip(self.unique_tokens, string_tokens):
+             self.token_dict[string_token] = int_token
+
+     def convert_level_to_tensor(self, level: List[str]):
+         str_arr = flip_and_transpose(np.array(characterize(level)))
+         str_arr = "".join(join_list_of_list(str_arr))
+
+         x = self.tokenizer(str_arr, return_tensors="pt")
+         return x, str_arr
+
+     def __len__(self):
+         return self.indices.shape[0]
+
+     def __getitem__(self, idx):
+         if isinstance(idx, int):
+             indices = self.indices[idx]
+         else:
+             indices = torch.stack([self.indices[i] for i in idx])
+         return self.input_ids[indices], self.attention_masks[indices]
+
+     def generate_indices(self):
+         out = []
+         for idx in range(self.input_ids.shape[0] - self.context_len):
+             if idx % self.height == 0 or self.sample_all_indices:
+                 arange = torch.arange(idx, idx + self.context_len)
+                 out.append(arange)
+         return torch.stack(out)
+
+     def sample_indices(self, batch_size):
+         out = []
+         for _ in range(batch_size):
+             start_idx = np.random.randint(0, self.__len__() - self.context_len)
+             indices = torch.arange(start_idx, start_idx + self.context_len)
+             out.append(indices)
+         return torch.stack(out)
+
+     def __str__(self):
+         str_list = characterize(self.tokenizer.batch_decode(self.x["input_ids"]))
+         string = "\n".join(
+             join_list_of_list(flip_and_transpose(np.array(str_list), True))
+         )
+         return string
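
A minimal usage sketch for the dataset above (the default level corpus ships with the package, and construction also trains a tokenizer on it):

    from mario_gpt.dataset import MarioDataset

    dataset = MarioDataset()                 # uses FULL_LEVEL_STR_WITH_PATHS
    input_ids, attention_mask = dataset[0]   # one 700-token training window
    print(len(dataset), input_ids.shape)

Note that levels are flipped and transposed before tokenization, so consecutive tokens run through 14-tile columns rather than across rows; that is why `generate_indices` steps in multiples of `height`.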
mario_gpt/level.py ADDED
The diff for this file is too large to render. See raw diff
 
mario_gpt/lm/__init__.py ADDED
@@ -0,0 +1,44 @@
+ from typing import Optional, Union
+
+ from transformers import PreTrainedModel, PreTrainedTokenizer
+
+ # lm stuff
+ from mario_gpt.lm.base import BaseMarioLM
+ from mario_gpt.lm.bert import MarioBert
+ from mario_gpt.lm.gpt import MarioGPT
+ from mario_gpt.prompter import Prompter
+
+
+ def MarioLM(
+     lm: Optional[PreTrainedModel] = None,
+     tokenizer: Optional[PreTrainedTokenizer] = None,
+     context_len: int = 700,
+     prompter: Optional[Prompter] = None,
+     mask_proportion: float = 0.15,
+     mask_model: bool = False,
+     lm_path: Optional[str] = None,
+     tokenizer_path: Optional[str] = None,
+     **kwargs
+ ) -> Union[MarioGPT, MarioBert]:
+     if not mask_model:
+         return MarioGPT(
+             lm=lm,
+             tokenizer=tokenizer,
+             context_len=context_len,
+             prompter=prompter,
+             lm_path=lm_path,
+             tokenizer_path=tokenizer_path,
+             **kwargs
+         )
+     return MarioBert(
+         lm=lm,
+         tokenizer=tokenizer,
+         context_len=context_len,
+         mask_proportion=mask_proportion,
+         lm_path=lm_path,
+         tokenizer_path=tokenizer_path,
+         **kwargs
+     )
+
+
+ __all__ = ["BaseMarioLM", "MarioGPT", "MarioBert", "MarioLM"]
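
`MarioLM` is a factory function rather than a class: it dispatches on `mask_model`. A quick sketch (pretrained weights download from the Hugging Face Hub on first use):

    from mario_gpt.lm import MarioLM

    gpt_lm = MarioLM()                   # MarioGPT, autoregressive generation
    bert_lm = MarioLM(mask_model=True)   # MarioBert, masked infilling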
mario_gpt/lm/base.py ADDED
@@ -0,0 +1,91 @@
+ import abc
+ import os
+ from typing import Any, Dict, Optional
+
+ import torch
+ from transformers import PreTrainedModel, PreTrainedTokenizer
+
+
+ class BaseMarioLM(metaclass=abc.ABCMeta):
+
+     PRETRAINED_LM_PATH = ""
+     PRETRAINED_TOKENIZER_PATH = ""
+
+     BASE_LM_PATH = ""
+     BASE_TOKENIZER_PATH = ""
+
+     def __init__(
+         self,
+         lm: Optional[PreTrainedModel] = None,
+         tokenizer: Optional[PreTrainedTokenizer] = None,
+         context_len: int = 700,
+         lm_path: Optional[str] = None,
+         tokenizer_path: Optional[str] = None,
+         lm_kwargs: Dict[str, Any] = {},
+         tokenizer_kwargs: Dict[str, Any] = {},
+     ):
+         self.load_pretrained(
+             lm_path, tokenizer_path, lm, tokenizer, lm_kwargs, tokenizer_kwargs
+         )
+         self.context_len = context_len
+
+     def train(self):
+         self.lm.train()
+
+     def eval(self):
+         self.lm.eval()
+
+     @property
+     def device(self):
+         return self.lm.device
+
+     def to(self, device: torch.device):
+         self.lm = self.lm.to(device)
+         return self
+
+     def save_model(self, checkpoint_path: str, it: int):
+         self.lm.save_pretrained(os.path.join(checkpoint_path, f"iteration_{it}"))
+
+     @abc.abstractmethod
+     def load_pretrained_lm(
+         self, path: str, lm_kwargs: Dict[str, Any]
+     ) -> PreTrainedModel:
+         """
+         Model to be used in level tile prediction
+         """
+
+     @abc.abstractmethod
+     def load_pretrained_tokenizer(
+         self, path: str, tokenizer_kwargs: Dict[str, Any]
+     ) -> PreTrainedTokenizer:
+         """
+         Tokenizer to be used to read / decode levels
+         """
+
+     def load_pretrained(
+         self,
+         lm_path: Optional[str] = None,
+         tokenizer_path: Optional[str] = None,
+         lm: Optional[PreTrainedModel] = None,
+         tokenizer: Optional[PreTrainedTokenizer] = None,
+         lm_kwargs: Dict[str, Any] = {},
+         tokenizer_kwargs: Dict[str, Any] = {},
+     ):
+         self.lm = lm
+         self.tokenizer = tokenizer
+
+         if lm is None:
+             if lm_path is None:
+                 lm_path = self.PRETRAINED_LM_PATH
+
+             print(f"Using {lm_path} lm")
+             self.lm = self.load_pretrained_lm(lm_path, lm_kwargs)
+
+         if tokenizer is None:
+             if tokenizer_path is None:
+                 tokenizer_path = self.PRETRAINED_TOKENIZER_PATH
+
+             print(f"Using {tokenizer_path} tokenizer")
+             self.tokenizer = self.load_pretrained_tokenizer(
+                 tokenizer_path, tokenizer_kwargs
+             )
mario_gpt/lm/bert.py ADDED
@@ -0,0 +1,95 @@
+ from __future__ import annotations
+
+ from typing import Any, Dict, Optional
+
+ import numpy as np
+ import torch
+ from transformers import (
+     AutoModelForMaskedLM,
+     AutoTokenizer,
+     PreTrainedModel,
+     PreTrainedTokenizer,
+     RobertaModel,
+     RobertaTokenizer,
+ )
+
+ from mario_gpt.lm.base import BaseMarioLM
+
+ PRETRAINED_MODEL_PATH = "shyamsn97/MarioBert-448-inpaint-context-length"
+
+
+ class MarioBert(BaseMarioLM):
+     PRETRAINED_LM_PATH = PRETRAINED_MODEL_PATH
+     PRETRAINED_TOKENIZER_PATH = PRETRAINED_MODEL_PATH
+
+     BASE_LM_PATH = "distilroberta-base"
+     BASE_TOKENIZER_PATH = "distilroberta-base"
+
+     def __init__(
+         self,
+         lm: Optional[PreTrainedModel] = None,
+         tokenizer: Optional[PreTrainedTokenizer] = None,
+         context_len: int = 448,
+         mask_proportion: float = 0.16,
+         lm_path: Optional[str] = None,
+         tokenizer_path: Optional[str] = None,
+         lm_kwargs: Dict[str, Any] = {},
+         tokenizer_kwargs: Dict[str, Any] = {},
+     ):
+         super().__init__(
+             lm,
+             tokenizer,
+             context_len,
+             lm_path,
+             tokenizer_path,
+             lm_kwargs,
+             tokenizer_kwargs,
+         )
+         self.mask_proportion = mask_proportion
+         self.mask_portion = int(self.context_len * self.mask_proportion)
+
+     def sample_mask(self, input_ids):
+         batch_size = input_ids.shape[0]
+         seq_len = input_ids.shape[-1]
+         mask_portion = self.mask_portion
+         sampled_start_idx = [i for i in range(seq_len - mask_portion) if i % 14 == 0]
+         sampled_start_idx = np.random.choice(sampled_start_idx, batch_size)
+         sampled_masks = []
+         for idx in sampled_start_idx:
+             mask = torch.arange(idx, idx + mask_portion)
+             sampled_masks.append(mask)
+         sampled_mask_indices = torch.stack(sampled_masks)
+         return self.apply_mask(input_ids, sampled_mask_indices)
+
+     def generate_mask(self, mask_len: int, batch_size: int = 1):
+         mask_token = self.tokenizer("<mask>").input_ids[1]
+         ones = torch.ones((batch_size, mask_len))
+         return ones * mask_token
+
+     def apply_mask(self, level, masked_indices, mask=None):
+         if len(level.shape) == 1:
+             level = level.unsqueeze(0)
+         batch_size = level.shape[0]
+         mask_len = masked_indices.shape[-1]
+         if mask is None:
+             mask = self.generate_mask(mask_len, batch_size)
+         mask = mask.long().to(level.device)
+         masked_level = level * torch.ones_like(level).to(level.device)
+         masked_level[:, masked_indices] = mask
+         return masked_level
+
+     def generate_seed(self, length: int, batch_size: Optional[int] = None):
+         seed = self.tokenizer("X", return_tensors="pt").input_ids.squeeze()[
+             1:-1
+         ]  # remove start and end tokens
+         if batch_size is None:
+             return seed.repeat(length)
+         return seed.view(1, 1).repeat(batch_size, length)
+
+     def load_pretrained_lm(self, path: str, lm_kwargs: Dict[str, Any]) -> RobertaModel:
+         return AutoModelForMaskedLM.from_pretrained(path, **lm_kwargs)
+
+     def load_pretrained_tokenizer(
+         self, path: str, tokenizer_kwargs: Dict[str, Any]
+     ) -> RobertaTokenizer:
+         return AutoTokenizer.from_pretrained(path, **tokenizer_kwargs)
mario_gpt/lm/gpt.py ADDED
@@ -0,0 +1,97 @@
+ from __future__ import annotations
+
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ from transformers import (
+     AutoConfig,
+     AutoModelWithLMHead,
+     AutoTokenizer,
+     GPT2Model,
+     GPT2Tokenizer,
+     PreTrainedModel,
+     PreTrainedTokenizer,
+ )
+
+ from mario_gpt.lm.base import BaseMarioLM
+ from mario_gpt.prompter import Prompter
+ from mario_gpt.sampler import GPTSampler, SampleOutput
+
+ PRETRAINED_MODEL_PATH = "shyamsn97/Mario-GPT2-700-context-length"
+
+
+ class MarioGPT(BaseMarioLM):
+     PRETRAINED_LM_PATH = PRETRAINED_MODEL_PATH
+     PRETRAINED_TOKENIZER_PATH = PRETRAINED_MODEL_PATH
+
+     BASE_LM_PATH = "distilgpt2"
+     BASE_TOKENIZER_PATH = "distilgpt2"
+
+     def __init__(
+         self,
+         lm: Optional[PreTrainedModel] = None,
+         tokenizer: Optional[PreTrainedTokenizer] = None,
+         context_len: int = 700,
+         prompter: Optional[Prompter] = None,
+         lm_path: Optional[str] = None,
+         tokenizer_path: Optional[str] = None,
+         lm_kwargs: Dict[str, Any] = {},
+         tokenizer_kwargs: Dict[str, Any] = {},
+     ):
+         super().__init__(
+             lm,
+             tokenizer,
+             context_len,
+             lm_path,
+             tokenizer_path,
+             lm_kwargs,
+             tokenizer_kwargs,
+         )
+         self.prompter = prompter
+         if prompter is None:
+             self.prompter = Prompter(self.tokenizer)
+
+     def generate_seed(self, length: int, batch_size: Optional[int] = None):
+         seed = self.tokenizer("X", return_tensors="pt").input_ids.squeeze()
+         if batch_size is None:
+             return seed.repeat(length)
+         return seed.view(1, 1).repeat(batch_size, length)
+
+     def load_pretrained_lm(self, path: str, lm_kwargs: Dict[str, Any]) -> GPT2Model:
+         if path == "random":
+             print("Initializing random weights...")
+             config = AutoConfig.from_pretrained(
+                 self.BASE_LM_PATH, **{**lm_kwargs, "add_cross_attention": True}
+             )
+             return AutoModelWithLMHead.from_config(config)
+         return AutoModelWithLMHead.from_pretrained(
+             path, **{**lm_kwargs, "add_cross_attention": True}
+         )
+
+     def load_pretrained_tokenizer(
+         self, path: str, tokenizer_kwargs: Dict[str, Any]
+     ) -> GPT2Tokenizer:
+         if path == "random":
+             return AutoTokenizer.from_pretrained(
+                 self.BASE_TOKENIZER_PATH, **tokenizer_kwargs
+             )
+         return AutoTokenizer.from_pretrained(path, **tokenizer_kwargs)
+
+     def sample(
+         self,
+         seed: Optional[torch.Tensor] = None,
+         prompts: Optional[List[str]] = None,
+         num_steps: int = 1,
+         temperature: float = 2.0,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         use_tqdm: bool = False,
+         return_tensor: bool = False,
+     ) -> SampleOutput:
+         sampler = GPTSampler(self, temperature, 16, self.context_len, use_tqdm)
+         return sampler(
+             seed=seed,
+             prompts=prompts,
+             num_steps=num_steps,
+             encoder_hidden_states=encoder_hidden_states,
+             return_tensor=return_tensor,
+         )
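
Sampling usage follows the upstream mario-gpt README; a sketch (each level column is 14 tokens, so `num_steps=1400` yields a 100-column level):

    from mario_gpt import MarioLM

    mario_lm = MarioLM()
    generated = mario_lm.sample(
        prompts=["many pipes, many enemies, some blocks, high elevation"],
        num_steps=1400,
        temperature=2.0,
        use_tqdm=True,
    )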
mario_gpt/prompter.py ADDED
@@ -0,0 +1,175 @@
+ from __future__ import annotations
+
+ import random
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from scipy import stats
+ from transformers import pipeline
+
+ from mario_gpt.dataset import MarioDataset
+ from mario_gpt.utils import view_level
+
+ STATISTICS = {
+     "enemy": np.array([1.0, 3.0, 7.0]),
+     "pipe": np.array([0.0, 2.0, 5.0]),
+     "block": np.array([50.0, 75.0, 176.0]),
+ }
+
+ FEATURE_EXTRACTION_MODEL = "facebook/bart-base"
+
+
+ class Prompter:
+     def __init__(
+         self,
+         level_tokenizer,
+         prompter_model: str = FEATURE_EXTRACTION_MODEL,
+         use_raw_counts: bool = False,
+         statistics: Optional[Dict[str, Any]] = None,
+     ):
+         self.prompter_model = prompter_model
+         self.feature_extraction = pipeline(
+             "feature-extraction",
+             model=prompter_model,
+             tokenizer=prompter_model,
+             framework="pt",
+         )
+
+         self.level_tokenizer = level_tokenizer
+
+         self.use_raw_counts = use_raw_counts
+         self.statistics = statistics
+         if statistics is None:
+             self.statistics = STATISTICS
+
+     @property
+     def pipe_thresholds(self) -> Tuple[List[int], List[str]]:
+         thresholds = self.statistics["pipe"]
+         keywords = ["no", "little", "some", "many"]
+         return thresholds, keywords
+
+     @property
+     def enemy_thresholds(self) -> Tuple[List[int], List[str]]:
+         thresholds = self.statistics["enemy"]
+         keywords = ["no", "little", "some", "many"]
+         return thresholds, keywords
+
+     @property
+     def block_thresholds(self) -> Tuple[List[int], List[str]]:
+         thresholds = self.statistics["block"]
+         keywords = ["little", "little", "some", "many"]
+         return thresholds, keywords
+
+     def count_pipes(self, flattened_level: str) -> int:
+         return flattened_level.count("<>")
+
+     def count_enemies(self, flattened_level: str) -> int:
+         return flattened_level.count("E") + flattened_level.count("B")
+
+     def count_blocks(self, flattened_level: str) -> int:
+         return np.sum([flattened_level.count(char) for char in ["X", "S", "?", "Q"]])
+
+     def _flatten_level(self, string_level: List[str]) -> str:
+         return "".join(string_level)
+
+     def pipe_prompt(self, flattened_level: str, level: str) -> Tuple[str, str]:
+         count = self.count_pipes(flattened_level)
+         keyword = f"{count}"
+         if not self.use_raw_counts:
+             thresholds, keywords = self.pipe_thresholds
+             threshold = np.digitize(count, thresholds, right=True)
+             keyword = keywords[threshold]
+         return f"{keyword} pipes", keyword
+
+     def enemy_prompt(self, flattened_level: str, level: str) -> Tuple[str, str]:
+         count = self.count_enemies(flattened_level)
+         keyword = f"{count}"
+         if not self.use_raw_counts:
+             thresholds, keywords = self.enemy_thresholds
+             threshold = np.digitize(count, thresholds, right=True)
+             keyword = keywords[threshold]
+         return f"{keyword} enemies", keyword
+
+     def block_prompt(self, flattened_level: str, level: str) -> Tuple[str, str]:
+         count = self.count_blocks(flattened_level)
+         keyword = f"{count}"
+         if not self.use_raw_counts:
+             thresholds, keywords = self.block_thresholds
+             threshold = np.digitize(count, thresholds, right=True)
+             keyword = keywords[threshold]
+         return f"{keyword} blocks", keyword
+
+     def elevation_prompt(self, flattened_level: str, level: str) -> Tuple[str, str]:
+         top_levels = level[:6]  # elevation 8 and up
+         for t in top_levels:
+             if "X" in t or "<" in t or ">" in t:
+                 return "high elevation", "high"
+         return "low elevation", "low"
+
+     def output_hidden(self, prompt: str, device: torch.device = torch.device("cpu")):
+         # Reducing along the first dimension to get a 768 dimensional array
+         return (
+             self.feature_extraction(prompt, return_tensors="pt")[0]
+             .mean(0)
+             .to(device)
+             .view(1, -1)
+         )
+
+     def dataset_statistics(self, dataset: MarioDataset):
+         enemy_counts = []
+         pipe_counts = []
+         block_counts = []
+         for i in range(len(dataset)):
+             level, _ = dataset[i]
+             str_level = self._flatten_level(view_level(level, dataset.tokenizer))
+
+             enemy_count = self.count_enemies(str_level)
+             pipe_count = self.count_pipes(str_level)
+             block_count = self.count_blocks(str_level)
+
+             enemy_counts.append(enemy_count)
+             pipe_counts.append(pipe_count)
+             block_counts.append(block_count)
+         d = {"enemy": {}, "pipe": {}, "block": {}}
+
+         d["enemy"] = stats.mstats.mquantiles(enemy_counts, [0.33, 0.66, 0.95])
+         d["pipe"] = stats.mstats.mquantiles(pipe_counts, [0.33, 0.66, 0.95])
+         d["block"] = stats.mstats.mquantiles(block_counts, [0.33, 0.66, 0.95])
+         return d
+
+     def __call__(
+         self, level: torch.Tensor = None, sample_prompt: bool = False
+     ) -> Tuple[str, torch.Tensor, Dict[str, str], Optional[List[str]]]:
+         device: torch.device = torch.device("cpu")
+         if not sample_prompt:
+             if level is None:
+                 raise ValueError("Level must be provided if sample_prompt is not true!")
+             str_level = view_level(level, self.level_tokenizer)
+             flattened_level = self._flatten_level(str_level)
+
+             pipe_prompt, _ = self.pipe_prompt(flattened_level, str_level)
+             enemy_prompt, _ = self.enemy_prompt(flattened_level, str_level)
+             block_prompt, _ = self.block_prompt(flattened_level, str_level)
+             elevation_prompt, _ = self.elevation_prompt(flattened_level, str_level)
+             device = level.device
+         else:
+             str_level = None
+             pipe_prompt = random.choice(["no", "little", "some", "many"]) + " pipes"
+             enemy_prompt = random.choice(["no", "little", "some", "many"]) + " enemies"
+             block_prompt = (
+                 random.choice(["little", "little", "some", "many"]) + " blocks"
+             )  # levels always have blocks
+             elevation_prompt = (
+                 random.choice(["low", "high"]) + " elevation"
+             )  # elevation is always either low or high
+
+         prompt_dict = {
+             "pipe": pipe_prompt,
+             "enemy": enemy_prompt,
+             "block": block_prompt,
+             "elevation_prompt": elevation_prompt,
+         }
+         prompt = f"{pipe_prompt}, {enemy_prompt}, {block_prompt}, {elevation_prompt}"
+         hidden = self.output_hidden(prompt, device=device)
+         return prompt, hidden, prompt_dict, str_level
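
The threshold logic turns raw tile counts into the keyword vocabulary the model conditions on. A small worked example with the `enemy` statistics above:

    import numpy as np

    thresholds = np.array([1.0, 3.0, 7.0])   # STATISTICS["enemy"]
    keywords = ["no", "little", "some", "many"]
    keywords[np.digitize(0, thresholds, right=True)]    # 'no'     (0 <= 1)
    keywords[np.digitize(2, thresholds, right=True)]    # 'little' (1 < 2 <= 3)
    keywords[np.digitize(10, thresholds, right=True)]   # 'many'   (10 > 7)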
mario_gpt/sampler.py ADDED
@@ -0,0 +1,370 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from PIL.Image import Image
+ from tqdm import tqdm
+ from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper
+
+ from mario_gpt.lm.base import BaseMarioLM
+ from mario_gpt.prompter import Prompter
+ from mario_gpt.simulator import Simulator
+ from mario_gpt.utils import (
+     convert_level_to_png,
+     load_level,
+     save_level,
+     trim_level,
+     view_level,
+ )
+
+
+ @dataclass
+ class SampleOutput:
+     level: Optional[List[str]]
+     prompt: Optional[str] = None
+     img: Optional[Image] = None
+     sample_predictions_str: Optional[List[str]] = None
+     sample_predictions_img: Optional[Image] = None
+     level_tensor: Optional[torch.Tensor] = None
+     sample_predictions_tensor: Optional[torch.Tensor] = None
+
+     @classmethod
+     def create(
+         cls,
+         level_tensor: torch.Tensor,
+         sample_predictions_tensor: torch.Tensor,
+         tokenizer,
+         prompter: Optional[Prompter] = None,
+     ) -> SampleOutput:
+         # batch = 1
+         level = None
+         img = None
+
+         try:
+             level = view_level(level_tensor, tokenizer)
+             img = convert_level_to_png(level)[0]
+         except Exception as e:
+             print(
+                 f"Failed to generate string or image representation for full level! Got error {e}"
+             )
+             level = None
+             img = None
+         try:
+             sample_predictions_str = view_level(sample_predictions_tensor, tokenizer)
+             sample_predictions_img = convert_level_to_png(sample_predictions_str)[0]
+         except Exception as e:
+             print(
+                 f"Failed to generate string or image representation for sampled predictions! Got error {e}"
+             )
+             sample_predictions_str = None
+             sample_predictions_img = None
+
+         prompt = None
+         if prompter is not None:
+             prompt = prompter(level_tensor)[0]
+
+         return SampleOutput(
+             level,
+             prompt,
+             img,
+             sample_predictions_str,
+             sample_predictions_img,
+             level_tensor,
+             sample_predictions_tensor,
+         )
+
+     @classmethod
+     def from_level_predictions(
+         cls,
+         level: torch.Tensor,
+         sample_predictions: torch.Tensor,
+         tokenizer,
+         prompter: Optional[Prompter] = None,
+     ) -> Union[SampleOutput, List[SampleOutput]]:
+         level_tensor = trim_level(level).squeeze().detach().cpu()
+         sample_predictions_tensor = (
+             trim_level(sample_predictions).squeeze().detach().cpu()
+         )
+
+         if len(level_tensor.shape) == 1:
+             return SampleOutput.create(
+                 level_tensor, sample_predictions_tensor, tokenizer, prompter
+             )
+
+         out = []
+         for _level_tensor, _sample_predictions_tensor in zip(
+             level_tensor, sample_predictions_tensor
+         ):
+             sample_output = SampleOutput.create(
+                 _level_tensor, _sample_predictions_tensor, tokenizer, prompter
+             )
+             out.append(sample_output)
+         return out
+
+     def save(self, filename: str) -> str:
+         save_level(self.level, filename)
+         return filename
+
+     @classmethod
+     def load(cls, filename: str) -> SampleOutput:
+         level = load_level(filename)
+         return SampleOutput(level=level)
+
+     def play(self):
+         simulator = Simulator(level=self.level)
+         simulator.interactive()
+
+     def run_astar(self, render=True):
+         simulator = Simulator(level=self.level)
+         simulator.astar(render)
+
+
+ class GPTSampler:
+     def __init__(
+         self,
+         mario_lm: BaseMarioLM,
+         temperature: float = 2.0,
+         top_k: int = 16,
+         context_len: int = 700,
+         use_tqdm: bool = False,
+         use_argmax: bool = False,
+     ):
+         self.mario_lm = mario_lm
+         self.temperature = temperature
+         self.top_k = top_k
+         self.context_len = context_len
+         self.use_tqdm = use_tqdm
+         self.use_argmax = use_argmax
+         self.logits_processor = LogitsProcessorList()
+         self.logits_warper = LogitsProcessorList(
+             [
+                 TopKLogitsWarper(top_k),  # number of characters
+                 TemperatureLogitsWarper(temperature),
+             ]
+         )
+
+     @property
+     def device(self) -> torch.device:
+         return self.mario_lm.device
+
+     def step(
+         self,
+         seed: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         with torch.no_grad():
+             attention_mask = torch.ones_like(seed).to(seed.device)
+             input_ids = seed
+             out = self.mario_lm.lm(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 encoder_hidden_states=encoder_hidden_states,
+                 token_type_ids=None,
+             )
+             logits = out.logits.detach()
+             if len(logits.shape) == 2:
+                 logits = logits.view(1, 1, -1)
+             next_token_logits = logits[:, -1, :]
+
+             if self.use_argmax:
+                 next_tokens = next_token_logits.argmax(-1)
+             else:
+                 next_token_scores = self.logits_processor(input_ids, next_token_logits)
+                 next_token_scores = self.logits_warper(input_ids, next_token_scores)
+                 probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
+                 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+         return next_tokens, encoder_hidden_states
+
+     def sample(
+         self,
+         seed: Union[Optional[torch.Tensor], Optional[SampleOutput]] = None,
+         prompts: Optional[List[str]] = None,
+         num_steps: int = 1,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         return_tensor: bool = False,
+     ):
+         self.mario_lm.eval()
+         context_len = self.context_len - 28
+         with torch.no_grad():
+             if seed is None:
+                 seed = self.mario_lm.generate_seed(1, batch_size=len(prompts)).to(
+                     self.device
+                 )
+                 out_tensor = seed.to(self.device)
+             elif isinstance(seed, SampleOutput):
+                 out_tensor = seed.level_tensor.to(self.device).squeeze()
+             else:
+                 out_tensor = seed.to(self.device).squeeze()
+             if len(out_tensor.shape) < 2:
+                 # if we pass in a single seed vector, then we repeat for each prompt
+                 # Otherwise, we treat inputs as separate seed-prompt pairs
+                 out_tensor = out_tensor.view(1, -1).repeat(len(prompts), 1)
+             if encoder_hidden_states is None:
+                 if prompts is not None:
+                     encoder_hidden_states = torch.stack(
+                         [
+                             self.mario_lm.prompter.output_hidden(prompt)
+                             for prompt in prompts
+                         ]
+                     )
+                 else:
+                     encoder_hidden_states = torch.stack(
+                         [
+                             self.mario_lm.prompter(sample_prompt=True)[1]
+                             for _ in range(seed.shape[0])
+                         ]
+                     )
+             encoder_hidden_states = encoder_hidden_states.to(
+                 self.device
+             )  # b x 1 x hidden_dim
+             encoder_hidden_states = encoder_hidden_states.view(
+                 out_tensor.shape[0], 1, -1
+             )
+         if not self.use_tqdm:
+             bar = np.arange(num_steps)
+         else:
+             bar = tqdm(np.arange(num_steps))
+         with torch.no_grad():
+             for i in bar:
+                 inp = out_tensor * 1
+                 if len(out_tensor.shape) > 0 and out_tensor.shape[-1] > context_len:
+                     diff = inp.shape[-1] % 14  # height of mario level
+                     ctx = context_len + diff
+                     inp = inp[:, -ctx:] * 1
+                 next_tokens, encoder_hidden_states = self.step(
+                     inp,
+                     encoder_hidden_states=encoder_hidden_states,
+                 )
+                 out_tensor = torch.cat(
+                     [out_tensor, next_tokens.unsqueeze(-1)], dim=-1
+                 )
+                 if self.use_tqdm:
+                     bar.set_description(
+                         f"shape: {inp.shape}, {out_tensor.shape} first: {inp[0][0]}, last: {out_tensor[0][-1]}"
+                     )
+         if self.use_tqdm:
+             bar.close()
+         sample_out = SampleOutput.from_level_predictions(
+             out_tensor,
+             out_tensor[:, -num_steps:],
+             self.mario_lm.tokenizer,
+             self.mario_lm.prompter,
+         )
+         self.mario_lm.train()
+         if return_tensor:
+             return sample_out, out_tensor
+         return sample_out
+
+     def __call__(self, *args, **kwargs):
+         return self.sample(*args, **kwargs)
+
+
+ class BertSampler:
+     def __init__(
+         self,
+         mario_lm: BaseMarioLM,
+         temperature: float = 2.0,
+         top_k: int = 16,
+         context_len: int = 448,
+         mask_proportion: float = 0.16,
+         use_argmax: bool = False,
+     ):
+         self.mario_lm = mario_lm
+         self.temperature = temperature
+         self.top_k = top_k
+         self.use_argmax = use_argmax  # take the most likely token instead of sampling
+         self.logits_processor = LogitsProcessorList()
+         self.logits_warper = LogitsProcessorList(
+             [
+                 TopKLogitsWarper(top_k),  # number of characters
+                 TemperatureLogitsWarper(temperature),
+             ]
+         )
+         self.context_len = context_len
+         self.mask_proportion = mask_proportion
+         self.mask_portion = int(self.context_len * self.mask_proportion)
+         self.mask_portion = self.mask_portion - self.mask_portion % 14 + 14
+
+     @property
+     def device(self) -> torch.device:
+         return self.mario_lm.device
+
+     def get_context(self, input_ids, mask_indices):
+         start_idx = mask_indices[0]
+         end_idx = mask_indices[-1]
+
+         if input_ids.shape[-1] <= self.context_len:
+             # sequence already fits in the context window; trim any partial column
+             clipped = input_ids.shape[-1] % 14
+             if clipped > 0:
+                 input_ids = input_ids[:-clipped]
+             return input_ids
+
+         # center the masked span inside the context window
+         portion = (self.context_len - self.mask_portion) // 2
+
+         remainder = 0
+         left = start_idx - portion
+         if left < 0:
+             remainder = -1 * left
+             left = 0
+
+         right = end_idx + portion + remainder
+
+         return input_ids[int(left) : int(right)]
+
+     def sample(
+         self,
+         seed: Union[torch.Tensor, SampleOutput],
+         mask: torch.Tensor,
+         return_tensor: bool = False,
+     ):
+         self.mario_lm.eval()
+         mask_indices = mask.nonzero()
+         input_ids = seed
+         if isinstance(seed, SampleOutput):
+             input_ids = seed.level_tensor.to(self.device).squeeze()
+
+         input_id_list = []
+         for i in range(input_ids.shape[0]):
+             input_id = input_ids[i]
+             mask_index = mask_indices[mask_indices[:, 0] == i][:, -1]
+             input_id = self.get_context(input_id, mask_index)
+             input_id_list.append(input_id)
+         input_ids = torch.stack(input_id_list, dim=0).to(self.device)
+
+         attention_mask = torch.ones_like(input_ids).to(input_ids.device)
+
+         if len(input_ids.shape) < 2:
+             # if we pass in a single seed vector, then we repeat for each prompt
+             # Otherwise, we treat inputs as separate seed-prompt pairs
+             input_ids = input_ids.view(1, -1)
+
+         out = self.mario_lm.lm(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=None,
+         )
+         logits = out.logits.detach()
+         if len(logits.shape) == 2:
+             logits = logits.view(1, 1, -1)
+
+         if self.use_argmax:
+             tokens = logits.argmax(-1)
+         else:
+             token_scores = self.logits_processor(input_ids, logits)
+             token_scores = self.logits_warper(input_ids, token_scores)
+             probs = torch.nn.functional.softmax(token_scores, dim=-1)
+             # sample one token per position; multinomial expects a 2D input
+             batch_size, seq_len, vocab_size = probs.shape
+             tokens = torch.multinomial(
+                 probs.view(-1, vocab_size), num_samples=1
+             ).view(batch_size, seq_len)
+
+         out = input_ids.detach()
+
+         for i in range(input_ids.shape[0]):
+             mask_index = mask_indices[mask_indices[:, 0] == i][:, -1]
+             out[i, mask_index] = tokens[i, mask_index].detach()
+
+         sample_out = SampleOutput.from_level_predictions(
+             out,
+             tokens,
+             self.mario_lm.tokenizer,
+             self.mario_lm.prompter,
+         )
+         self.mario_lm.train()
+         if return_tensor:
+             return sample_out, tokens
+         return sample_out
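
`SampleOutput` is also the entry point for saving and replaying results; a sketch, assuming a Java runtime is available for the bundled simulator jars:

    generated = mario_lm.sample(prompts=["many pipes"], num_steps=1400)
    level = generated[0] if isinstance(generated, list) else generated
    level.save("generated_level.txt")
    reloaded = SampleOutput.load("generated_level.txt")
    reloaded.run_astar()   # watch the A* agent attempt the level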
mario_gpt/simulator/PlayAstar.jar ADDED
Binary file (78.1 kB).
 
mario_gpt/simulator/PlayLevel.jar ADDED
Binary file (78 kB).
 
mario_gpt/simulator/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from mario_gpt.simulator.simulator import Simulator
+
+ __all__ = ["Simulator"]