multimodalart HF staff commited on
Commit
850b0e4
1 Parent(s): b995aaa

MarioGPT first attempt

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
1
+ .DS_Store
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The MIT License (MIT)
2
+
3
+ Copyright (c) 2023 Shyam Sudhakaran
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Makefile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts
2
+
3
+ clean-build: ## remove build artifacts
4
+ rm -fr build/
5
+ rm -fr dist/
6
+ rm -fr .eggs/
7
+ find . -name '*.egg-info' -exec rm -fr {} +
8
+ find . -name '*.egg' -exec rm -f {} +
9
+
10
+ clean-pyc: ## remove Python file artifacts
11
+ find . -name '*.pyc' -exec rm -f {} +
12
+ find . -name '*.pyo' -exec rm -f {} +
13
+ find . -name '*~' -exec rm -f {} +
14
+ find . -name '__pycache__' -exec rm -fr {} +
15
+
16
+ clean-test: ## remove test and coverage artifacts
17
+ rm -fr .tox/
18
+ rm -f .coverage
19
+ rm -fr coverage/
20
+ rm -fr .pytest_cache
21
+
22
+ lint: ## check style with flake8
23
+ isort --profile black mario_gpt
24
+ black mario_gpt
25
+ flake8 mario_gpt
26
+
27
+ install: clean lint
28
+ python setup.py install
app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from mario_gpt.dataset import MarioDataset
4
+ from mario_gpt.prompter import Prompter
5
+ from mario_gpt.lm import MarioLM
6
+ from mario_gpt.utils import view_level, convert_level_to_png
7
+
8
+ mario_lm = MarioLM()
9
+
10
+ device = torch.device('cuda')
11
+ mario_lm = mario_lm.to(device)
12
+ TILE_DIR = "data/tiles"
13
+
14
+ def update(prompt, progress=gr.Progress(track_tqdm=True)):
15
+ prompts = [prompt]
16
+ generated_level = mario_lm.sample(
17
+ prompts=prompts,
18
+ num_steps=1399,
19
+ temperature=2.0,
20
+ use_tqdm=True
21
+ )
22
+ img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
23
+ return img
24
+
25
+ with gr.Blocks() as demo:
26
+ with gr.Row():
27
+ prompt = gr.Textbox(label="Enter your MarioGPT prompt")
28
+ level_image = gr.Image()
29
+ btn = gr.Button("Generate level")
30
+ btn.click(fn=update, inputs=prompt, outputs=level_image)
31
+ pass
32
+ demo.launch()
data/tiles/N.png ADDED
data/tiles/Y.png ADDED
data/tiles/cannon_bottom.png ADDED
data/tiles/cannon_top.png ADDED
data/tiles/flying_koopa.png ADDED
data/tiles/ki-background.png ADDED
data/tiles/ki-door.png ADDED
data/tiles/ki-hazard.png ADDED
data/tiles/ki-moving-platform.png ADDED
data/tiles/ki-passable.png ADDED
data/tiles/ki-path.png ADDED
data/tiles/ki-unpassable.png ADDED
data/tiles/mm-CMM.png ADDED
data/tiles/mm-DMM.png ADDED
data/tiles/mm-HMM.png ADDED
data/tiles/mm-LMM.png ADDED
data/tiles/mm-MMM.png ADDED
data/tiles/mm-TMM.png ADDED
data/tiles/mma_tiles.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6d58bb3228bcd3c653c4a58b69044588ffd6e5e4c946a860497a39d84eb60b8
3
+ size 6586
data/tiles/plant.png ADDED
data/tiles/smb-background.png ADDED
data/tiles/smb-breakable.png ADDED
data/tiles/smb-coin.png ADDED
data/tiles/smb-enemy.png ADDED
data/tiles/smb-path.png ADDED
data/tiles/smb-question.png ADDED
data/tiles/smb-tube-lower-left.png ADDED
data/tiles/smb-tube-lower-right.png ADDED
data/tiles/smb-tube-top-left.png ADDED
data/tiles/smb-tube-top-right.png ADDED
data/tiles/smb-unpassable.png ADDED
data/tiles/smb_enemies_sheet.png ADDED
data/tiles/tile004 (1).png ADDED
data/tiles/tile004 (2).png ADDED
data/tiles/tile004.png ADDED
mario_gpt/__init__.py ADDED
File without changes
mario_gpt/dataset.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
9
+
10
+ from mario_gpt.level import FULL_LEVEL_STR_WITH_PATHS
11
+
12
+ DEFAULT_MODEL = "distilgpt2"
13
+
14
+
15
+ def split_given_size(a, size):
16
+ return np.split(a, np.arange(size, len(a), size))
17
+
18
+
19
+ def flip_and_transpose(arr: np.array, flip_first: bool = False):
20
+ if arr.shape[-1] > 1:
21
+ if flip_first:
22
+ return np.flip(arr, -1).transpose()
23
+ return np.flip(arr.transpose(), -1)
24
+ return arr
25
+
26
+
27
+ def join_list_of_list(str_lists):
28
+ return ["".join(s) for s in str_lists]
29
+
30
+
31
+ def characterize(str_lists):
32
+ return [list(s) for s in str_lists]
33
+
34
+
35
+ class MarioDataset(Dataset):
36
+ def __init__(
37
+ self,
38
+ tokenizer: Optional[PreTrainedTokenizer] = None,
39
+ level_string: Optional[str] = None,
40
+ context_len: int = 700,
41
+ height: int = 14,
42
+ remove_start_end_tokens: bool = False,
43
+ sample_all_indices: bool = False,
44
+ ):
45
+ if level_string is None:
46
+ print(
47
+ "No level string specified, using default string FULL_LEVEL_STR_WITH_PATHS..."
48
+ )
49
+ level_string = FULL_LEVEL_STR_WITH_PATHS
50
+ elif ".txt" in level_string:
51
+ with open(level_string, "r") as file:
52
+ level_string = file.read()
53
+
54
+ self.character_set = set(level_string)
55
+ if "\n" in self.character_set:
56
+ self.character_set.remove("\n")
57
+ self.vocab_size = len(self.character_set)
58
+ self.sample_all_indices = sample_all_indices
59
+
60
+ def get_training_corpus():
61
+ yield list(level_string)
62
+
63
+ if tokenizer is None:
64
+ tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
65
+
66
+ self.tokenizer = tokenizer
67
+ if getattr(tokenizer, "train_new_from_iterator", None) is not None:
68
+ self.tokenizer = tokenizer.train_new_from_iterator(
69
+ get_training_corpus(), 52000
70
+ )
71
+ elif getattr(tokenizer, "train_from_iterator", None) is not None:
72
+ self.tokenizer = PreTrainedTokenizerFast(tokenizer_object=self.tokenizer)
73
+ self.tokenizer = self.tokenizer.train_new_from_iterator(
74
+ get_training_corpus(), self.vocab_size
75
+ )
76
+ self.context_len = context_len
77
+ self.height = height
78
+
79
+ x, self.str_arr = self.convert_level_to_tensor(level_string.split("\n"))
80
+ self.input_ids = x["input_ids"].squeeze()
81
+ self.attention_masks = x["attention_mask"].squeeze()
82
+ if remove_start_end_tokens:
83
+ self.input_ids = self.input_ids[1:-1]
84
+ self.attention_masks = self.attention_masks[1:-1]
85
+
86
+ self.indices = self.generate_indices()
87
+
88
+ self.unique_tokens, self.unique_counts = self.input_ids.unique(
89
+ return_counts=True
90
+ )
91
+ self.weighted_unique_counts = (
92
+ 1.0 / self.unique_counts / torch.sum(self.unique_counts)
93
+ )
94
+
95
+ self.token_dict = {}
96
+ string_tokens = list(self.tokenizer.decode(self.unique_tokens))
97
+ for int_token, string_token in zip(self.unique_tokens, string_tokens):
98
+ self.token_dict[string_token] = int_token
99
+
100
+ def convert_level_to_tensor(self, level: List[str]):
101
+ str_arr = flip_and_transpose(np.array(characterize(level)))
102
+ str_arr = "".join(join_list_of_list(str_arr))
103
+
104
+ x = self.tokenizer(str_arr, return_tensors="pt")
105
+ return x, str_arr
106
+
107
+ def __len__(self):
108
+ return self.indices.shape[0]
109
+
110
+ def __getitem__(self, idx):
111
+ indices = self.indices[idx]
112
+ return self.input_ids[indices], self.attention_masks[indices]
113
+
114
+ def generate_indices(self):
115
+ out = []
116
+ for idx in range(self.input_ids.shape[0] - self.context_len):
117
+ if idx % self.height == 0 or self.sample_all_indices:
118
+ arange = torch.arange(idx, idx + self.context_len)
119
+ out.append(arange)
120
+ return torch.stack(out)
121
+
122
+ def sample_indices(self, batch_size):
123
+ out = []
124
+ for _ in range(batch_size):
125
+ start_idx = np.random.randint(0, self.__len__() - self.context_len)
126
+ indices = torch.arange(start_idx, start_idx + self.context_len)
127
+ out.append(indices)
128
+ return torch.stack(out)
129
+
130
+ def __str__(self):
131
+ str_list = characterize(self.tokenizer.batch_decode(self.x["input_ids"]))
132
+ string = "\n".join(
133
+ join_list_of_list(flip_and_transpose(np.array(str_list), True))
134
+ )
135
+ return string
136
+
137
+ def generate_mask(self, mask_len: int, batch_size: int = 1):
138
+ mask_token = self.tokenizer("<mask>").input_ids[1]
139
+ ones = torch.ones((batch_size, mask_len))
140
+ return ones * mask_token
141
+
142
+ def apply_mask(self, level, masked_indices, mask=None):
143
+ if len(level.shape) == 1:
144
+ level = level.unsqueeze(0)
145
+ batch_size = level.shape[0]
146
+ mask_len = masked_indices.shape[-1]
147
+ if mask is None:
148
+ mask = self.generate_mask(mask_len, batch_size)
149
+ mask = mask.long().to(level.device)
150
+ masked_level = level * torch.ones_like(level).to(level.device)
151
+ masked_level[:, masked_indices] = mask
152
+ return masked_level
mario_gpt/level.py ADDED
The diff for this file is too large to render. See raw diff
mario_gpt/lm.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ import numpy as np
4
+ import torch
5
+ from tqdm import tqdm
6
+ from transformers import (
7
+ AutoModelWithLMHead,
8
+ AutoTokenizer,
9
+ GPT2Model,
10
+ GPT2Tokenizer,
11
+ LogitsProcessorList,
12
+ PreTrainedModel,
13
+ PreTrainedTokenizer,
14
+ TemperatureLogitsWarper,
15
+ TopKLogitsWarper,
16
+ )
17
+
18
+ from mario_gpt.prompter import Prompter
19
+
20
+ PRETRAINED_MODEL_PATH = "shyamsn97/Mario-GPT2-700-context-length"
21
+
22
+
23
+ class MarioLM:
24
+ def __init__(
25
+ self,
26
+ lm: Optional[PreTrainedModel] = None,
27
+ tokenizer: Optional[PreTrainedTokenizer] = None,
28
+ context_len: int = 700,
29
+ prompter: Optional[Prompter] = None,
30
+ ):
31
+ self.context_len = context_len
32
+ self.lm = lm
33
+
34
+ if lm is None:
35
+ self.lm = self.load_pretrained_lm()
36
+
37
+ self.tokenizer = tokenizer
38
+ if tokenizer is None:
39
+ self.tokenizer = self.load_pretrained_tokenizer()
40
+
41
+ self.prompter = prompter
42
+ if prompter is None:
43
+ self.prompter = Prompter(self.tokenizer)
44
+
45
+ @property
46
+ def device(self):
47
+ return self.lm.device
48
+
49
+ def to(self, device: torch.device):
50
+ self.lm = self.lm.to(device)
51
+ return self
52
+
53
+ def load_pretrained_lm(self) -> GPT2Model:
54
+ print(f"Using {PRETRAINED_MODEL_PATH} model")
55
+ return AutoModelWithLMHead.from_pretrained(PRETRAINED_MODEL_PATH)
56
+
57
+ def load_pretrained_tokenizer(self) -> GPT2Tokenizer:
58
+ print(f"Using {PRETRAINED_MODEL_PATH} tokenizer")
59
+ return AutoTokenizer.from_pretrained(PRETRAINED_MODEL_PATH)
60
+
61
+ def sample_step(
62
+ self,
63
+ seed: torch.Tensor,
64
+ encoder_hidden_states: torch.Tensor,
65
+ temperature: float = 2.0,
66
+ ):
67
+ lm = self.lm
68
+ logits_processor = LogitsProcessorList()
69
+ logits_warper = LogitsProcessorList(
70
+ [
71
+ TopKLogitsWarper(16), # number of characters
72
+ TemperatureLogitsWarper(temperature),
73
+ ]
74
+ )
75
+ with torch.no_grad():
76
+ attention_mask = torch.ones_like(seed).to(seed.device)
77
+ input_ids = seed
78
+ out = lm(
79
+ input_ids=input_ids,
80
+ attention_mask=attention_mask,
81
+ encoder_hidden_states=encoder_hidden_states,
82
+ token_type_ids=None,
83
+ )
84
+ logits = out.logits.detach()
85
+ if len(logits.shape) == 2:
86
+ logits = logits.view(1, 1, -1)
87
+ next_token_logits = logits[:, -1, :]
88
+
89
+ next_token_scores = logits_processor(input_ids, next_token_logits)
90
+ next_token_scores = logits_warper(input_ids, next_token_scores)
91
+ probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
92
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
93
+ return next_tokens, encoder_hidden_states
94
+
95
+ def sample(
96
+ self,
97
+ seed: Optional[torch.Tensor] = None,
98
+ prompts: Optional[List[str]] = None,
99
+ num_steps: int = 1,
100
+ temperature: float = 2.0,
101
+ encoder_hidden_states: torch.Tensor = None,
102
+ use_tqdm: bool = False,
103
+ ):
104
+ context_len = self.context_len - 28
105
+ self.lm.eval()
106
+ with torch.no_grad():
107
+ if seed is None:
108
+ seed = self.tokenizer("X", return_tensors="pt").input_ids.view(1, 1)
109
+ out = seed.to(self.device)
110
+ if encoder_hidden_states is None:
111
+ if prompts is not None:
112
+ encoder_hidden_states = torch.stack(
113
+ [self.prompter.output_hidden(prompt) for prompt in prompts]
114
+ )
115
+ else:
116
+ encoder_hidden_states = torch.stack(
117
+ [
118
+ self.prompter(sample_prompt=True)[1]
119
+ for _ in range(seed.shape[0])
120
+ ]
121
+ )
122
+ encoder_hidden_states = encoder_hidden_states.to(
123
+ self.device
124
+ ) # b x 1 x hidden_dim
125
+ encoder_hidden_states = encoder_hidden_states.view(seed.shape[0], 1, -1)
126
+ if not use_tqdm:
127
+ bar = np.arange(num_steps)
128
+ else:
129
+ bar = tqdm(np.arange(num_steps))
130
+ with torch.no_grad():
131
+ for i in bar:
132
+ inp = out * 1
133
+ if len(out.shape) > 0 and out.shape[-1] > context_len:
134
+ diff = inp.shape[-1] % 14 # height of mario level
135
+ ctx = context_len + diff
136
+ inp = inp[:, -ctx:] * 1
137
+ next_tokens, encoder_hidden_states = self.sample_step(
138
+ inp,
139
+ encoder_hidden_states=encoder_hidden_states,
140
+ temperature=temperature,
141
+ )
142
+ out = torch.cat([out, next_tokens.unsqueeze(-1)], dim=-1)
143
+ if use_tqdm:
144
+ bar.set_description(
145
+ f"shape: {inp.shape}, {out.shape} first: {inp[0][0]}, last: {out[0][-1]}"
146
+ )
147
+ if use_tqdm:
148
+ bar.close()
149
+ self.lm.train()
150
+ return out
mario_gpt/prompter.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import random
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from scipy import stats
9
+ from transformers import pipeline
10
+
11
+ from mario_gpt.dataset import MarioDataset
12
+ from mario_gpt.utils import view_level
13
+
14
+ STATISTICS = {
15
+ "enemy": np.array([1.0, 3.0, 7.0]),
16
+ "pipe": np.array([0.0, 2.0, 5.0]),
17
+ "block": np.array([50.0, 75.0, 176.0]),
18
+ }
19
+
20
+ FEATURE_EXTRACTION_MODEL = "facebook/bart-base"
21
+
22
+
23
+ class Prompter:
24
+ def __init__(
25
+ self,
26
+ level_tokenizer,
27
+ prompter_model: str = FEATURE_EXTRACTION_MODEL,
28
+ use_raw_counts: bool = False,
29
+ statistics: Optional[Dict[str, Any]] = None,
30
+ ):
31
+ self.prompter_model = prompter_model
32
+ self.feature_extraction = pipeline(
33
+ "feature-extraction",
34
+ model=prompter_model,
35
+ tokenizer=prompter_model,
36
+ framework="pt",
37
+ )
38
+
39
+ self.level_tokenizer = level_tokenizer
40
+
41
+ self.use_raw_counts = use_raw_counts
42
+ self.statistics = statistics
43
+ if statistics is None:
44
+ self.statistics = STATISTICS
45
+
46
+ @property
47
+ def pipe_thresholds(self) -> Tuple[List[int], List[str]]:
48
+ thresholds = self.statistics["pipe"]
49
+ keywords = ["no", "little", "some", "many"]
50
+ return thresholds, keywords
51
+
52
+ @property
53
+ def enemy_thresholds(self) -> Tuple[List[int], List[str]]:
54
+ thresholds = self.statistics["enemy"]
55
+ keywords = ["no", "little", "some", "many"]
56
+ return thresholds, keywords
57
+
58
+ @property
59
+ def block_thresholds(self) -> Tuple[List[int], List[str]]:
60
+ thresholds = self.statistics["block"]
61
+ keywords = ["little", "little", "some", "many"]
62
+ return thresholds, keywords
63
+
64
+ def count_pipes(self, flattened_level: str) -> int:
65
+ return flattened_level.count("<>")
66
+
67
+ def count_enemies(self, flattened_level: str) -> int:
68
+ return flattened_level.count("E") + flattened_level.count("B")
69
+
70
+ def count_blocks(self, flattened_level: str) -> int:
71
+ return np.sum([flattened_level.count(char) for char in ["X", "S", "?", "Q"]])
72
+
73
+ def _flatten_level(self, string_level: List[str]) -> str:
74
+ return "".join(string_level)
75
+
76
+ def pipe_prompt(self, flattened_level: str, level: str) -> str:
77
+ count = self.count_pipes(flattened_level)
78
+ keyword = f"{count}"
79
+ if not self.use_raw_counts:
80
+ thresholds, keywords = self.pipe_thresholds
81
+ threshold = np.digitize(count, thresholds, right=True)
82
+ keyword = keywords[threshold]
83
+ return f"{keyword} pipes", keyword
84
+
85
+ def enemy_prompt(self, flattened_level: str, level: str) -> str:
86
+ count = self.count_enemies(flattened_level)
87
+ keyword = f"{count}"
88
+ if not self.use_raw_counts:
89
+ thresholds, keywords = self.enemy_thresholds
90
+ threshold = np.digitize(count, thresholds, right=True)
91
+ keyword = keywords[threshold]
92
+ return f"{keyword} enemies", keyword
93
+
94
+ def block_prompt(self, flattened_level: str, level: str) -> str:
95
+ count = self.count_blocks(flattened_level)
96
+ keyword = f"{count}"
97
+ if not self.use_raw_counts:
98
+ thresholds, keywords = self.block_thresholds
99
+ threshold = np.digitize(count, thresholds, right=True)
100
+ keyword = keywords[threshold]
101
+ return f"{keyword} blocks", keyword
102
+
103
+ def elevation_prompt(self, flattened_level: str, level: str):
104
+ top_levels = level[:6] # elevation 8 and up
105
+ for t in top_levels:
106
+ if "X" in t or "<" in t or ">" in t:
107
+ return "high elevation", "high"
108
+ return "low elevation", "low"
109
+
110
+ def output_hidden(self, prompt: str, device: torch.device = torch.device("cpu")):
111
+ # Reducing along the first dimension to get a 768 dimensional array
112
+ return (
113
+ self.feature_extraction(prompt, return_tensors="pt")[0]
114
+ .mean(0)
115
+ .to(device)
116
+ .view(1, -1)
117
+ )
118
+
119
+ def dataset_statistics(self, dataset: MarioDataset):
120
+ enemy_counts = []
121
+ pipe_counts = []
122
+ block_counts = []
123
+ for i in range(len(dataset)):
124
+ level, _ = dataset[i]
125
+ str_level = self._flatten_level(view_level(level, dataset.tokenizer))
126
+
127
+ enemy_count = self.count_enemies(str_level)
128
+ pipe_count = self.count_pipes(str_level)
129
+ block_count = self.count_blocks(str_level)
130
+
131
+ enemy_counts.append(enemy_count)
132
+ pipe_counts.append(pipe_count)
133
+ block_counts.append(block_count)
134
+ d = {"enemy": {}, "pipe": {}, "block": {}}
135
+
136
+ d["enemy"] = stats.mstats.mquantiles(enemy_counts, [0.33, 0.66, 0.95])
137
+ d["pipe"] = stats.mstats.mquantiles(pipe_counts, [0.33, 0.66, 0.95])
138
+ d["block"] = stats.mstats.mquantiles(block_counts, [0.33, 0.66, 0.95])
139
+ return d
140
+
141
+ def __call__(
142
+ self, level: torch.Tensor = None, sample_prompt: bool = False
143
+ ) -> Union[str, torch.Tensor]:
144
+ device: torch.device = torch.device("cpu")
145
+ if not sample_prompt:
146
+ if level is None:
147
+ raise ValueError("Level must be provided if sample_prompt is not true!")
148
+ str_level = view_level(level, self.level_tokenizer)
149
+ flattened_level = self._flatten_level(str_level)
150
+
151
+ pipe_prompt, _ = self.pipe_prompt(flattened_level, str_level)
152
+ enemy_prompt, _ = self.enemy_prompt(flattened_level, str_level)
153
+ block_prompt, _ = self.block_prompt(flattened_level, str_level)
154
+ elevation_prompt, _ = self.elevation_prompt(flattened_level, str_level)
155
+ device = level.device
156
+ else:
157
+ str_level = None
158
+ pipe_prompt = random.choice(["no", "little", "some", "many"]) + " pipes"
159
+ enemy_prompt = random.choice(["no", "little", "some", "many"]) + " enemies"
160
+ block_prompt = (
161
+ random.choice(["little", "little", "some", "many"]) + " blocks"
162
+ ) # levels always have blocks
163
+ elevation_prompt = (
164
+ random.choice(["low", "high"]) + " elevation"
165
+ ) # levels always have blocks
166
+
167
+ prompt_dict = {
168
+ "pipe": pipe_prompt,
169
+ "enemy": enemy_prompt,
170
+ "block": block_prompt,
171
+ "elevation_prompt": elevation_prompt,
172
+ }
173
+ prompt = f"{pipe_prompt}, {enemy_prompt}, {block_prompt}, {elevation_prompt}"
174
+ hidden = self.output_hidden(prompt, device=device)
175
+ return prompt, hidden, prompt_dict, str_level
mario_gpt/utils.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+
7
+
8
+ def characterize(str_lists):
9
+ return [list(s[::-1]) for s in str_lists]
10
+
11
+
12
+ def join_list_of_list(str_lists):
13
+ return ["".join(s) for s in str_lists]
14
+
15
+
16
+ def view_level(level_tokens, tokenizer):
17
+ str_list = [
18
+ s.replace("<mask>", "Y")
19
+ for s in tokenizer.batch_decode(level_tokens.detach().cpu().view(-1, 14))
20
+ ]
21
+ return join_list_of_list(np.array(characterize(str_list)).T)
22
+
23
+
24
+ def is_flying_enemy(array, row, col):
25
+ num_rows = array.shape[0]
26
+ if row == num_rows - 1:
27
+ return False
28
+ below = array[row + 1][col]
29
+ return below == "-"
30
+
31
+
32
+ def char_array_to_image(array, chars2pngs):
33
+ """
34
+ Convert a 16-by-16 array of integers into a PIL.Image object
35
+ param: array: a 16-by-16 array of integers
36
+ """
37
+ image = Image.new("RGB", (array.shape[1] * 16, array.shape[0] * 16))
38
+ for row in range(array.shape[0]):
39
+ for col, char in enumerate(array[row]):
40
+ value = chars2pngs["-"]
41
+ # if char == "E":
42
+ # if is_flying_enemy(array, row, col):
43
+ # char = "F"
44
+ if char in chars2pngs:
45
+ value = chars2pngs[char]
46
+ else:
47
+ print(f"REPLACING {value}", (col, row))
48
+
49
+ image.paste(value, (col * 16, row * 16))
50
+ return image
51
+
52
+
53
+ def convert_level_to_png(
54
+ level: Union[str, torch.Tensor], tiles_dir: str, tokenizer=None
55
+ ):
56
+ if isinstance(level, torch.Tensor):
57
+ level = view_level(level, tokenizer)
58
+ chars2pngs = {
59
+ "-": Image.open(f"{tiles_dir}/smb-background.png"),
60
+ "X": Image.open(f"{tiles_dir}/smb-unpassable.png"),
61
+ "S": Image.open(f"{tiles_dir}/smb-breakable.png"),
62
+ "?": Image.open(f"{tiles_dir}/smb-question.png"),
63
+ "Q": Image.open(f"{tiles_dir}/smb-question.png"),
64
+ "o": Image.open(f"{tiles_dir}/smb-coin.png"),
65
+ "E": Image.open(f"{tiles_dir}/smb-enemy.png"),
66
+ "<": Image.open(f"{tiles_dir}/smb-tube-top-left.png"),
67
+ ">": Image.open(f"{tiles_dir}/smb-tube-top-right.png"),
68
+ "[": Image.open(f"{tiles_dir}/smb-tube-lower-left.png"),
69
+ "]": Image.open(f"{tiles_dir}/smb-tube-lower-right.png"),
70
+ "x": Image.open(f"{tiles_dir}/smb-path.png"), # self-created
71
+ "Y": Image.open(f"{tiles_dir}/Y.png"), # self-created
72
+ "N": Image.open(f"{tiles_dir}/N.png"), # self-created
73
+ "B": Image.open(f"{tiles_dir}/cannon_top.png"),
74
+ "b": Image.open(f"{tiles_dir}/cannon_bottom.png"),
75
+ "F": Image.open(f"{tiles_dir}/flying_koopa.png"),
76
+ }
77
+ levels = [list(s) for s in level]
78
+ arr = np.array(levels)
79
+ return char_array_to_image(arr, chars2pngs), arr, level
80
+
81
+
82
+ TOKENS = [
83
+ "-",
84
+ "X",
85
+ "S",
86
+ "?",
87
+ "Q",
88
+ "o",
89
+ "E",
90
+ "<",
91
+ ">",
92
+ "[",
93
+ "]",
94
+ "x",
95
+ "Y",
96
+ "N",
97
+ "B",
98
+ "b",
99
+ ]
notebooks/Sampling.ipynb ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "df85b023-cdb5-498e-8373-0fd5b7c31853",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Load Stuff"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "895fc851-817b-4c23-baf4-72cf73238781",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import torch\n",
19
+ "from mario_gpt.dataset import MarioDataset\n",
20
+ "from mario_gpt.prompter import Prompter\n",
21
+ "from mario_gpt.lm import MarioLM\n",
22
+ "from mario_gpt.utils import view_level, convert_level_to_png"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "markdown",
27
+ "id": "28c11f07-b604-4603-8fe3-d53874ba02a8",
28
+ "metadata": {},
29
+ "source": [
30
+ "### Load Model"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 2,
36
+ "id": "6f656e57-24a6-4624-b6ed-8aa871581007",
37
+ "metadata": {},
38
+ "outputs": [
39
+ {
40
+ "name": "stdout",
41
+ "output_type": "stream",
42
+ "text": [
43
+ "Using shyamsn97/Mario-GPT2-700-context-length model\n"
44
+ ]
45
+ },
46
+ {
47
+ "name": "stderr",
48
+ "output_type": "stream",
49
+ "text": [
50
+ "/home/kokkgoblin/miniconda3/envs/py39/lib/python3.9/site-packages/transformers/models/auto/modeling_auto.py:1177: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
51
+ " warnings.warn(\n"
52
+ ]
53
+ },
54
+ {
55
+ "name": "stdout",
56
+ "output_type": "stream",
57
+ "text": [
58
+ "Using shyamsn97/Mario-GPT2-700-context-length tokenizer\n"
59
+ ]
60
+ }
61
+ ],
62
+ "source": [
63
+ "mario_lm = MarioLM()"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": 3,
69
+ "id": "1a60f6ed-42be-4d17-af15-151fa24e0f91",
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "TILE_DIR = \"../data/tiles\""
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "markdown",
78
+ "id": "a7d7bd55-14d4-45a3-9539-c7c385f63070",
79
+ "metadata": {},
80
+ "source": [
81
+ "### Load Dataset (Optional)"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": 4,
87
+ "id": "6c0840d0-ea5b-4111-9198-6b5a716083bd",
88
+ "metadata": {},
89
+ "outputs": [
90
+ {
91
+ "name": "stdout",
92
+ "output_type": "stream",
93
+ "text": [
94
+ "No level string specified, using default string FULL_LEVEL_STR_WITH_PATHS...\n",
95
+ "\n",
96
+ "\n",
97
+ "\n"
98
+ ]
99
+ },
100
+ {
101
+ "name": "stderr",
102
+ "output_type": "stream",
103
+ "text": [
104
+ "Token indices sequence length is longer than the specified maximum sequence length for this model (102116 > 1024). Running this sequence through the model will result in indexing errors\n"
105
+ ]
106
+ }
107
+ ],
108
+ "source": [
109
+ "dataset = MarioDataset(mario_lm.tokenizer)"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "markdown",
114
+ "id": "c80a131f-c68f-475d-ab24-acd3da814c39",
115
+ "metadata": {},
116
+ "source": [
117
+ "#### View string representation of level"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": 5,
123
+ "id": "2bdab45e-58cb-4bcb-8d6e-dee6c946d6fd",
124
+ "metadata": {},
125
+ "outputs": [
126
+ {
127
+ "data": {
128
+ "text/plain": [
129
+ "['--------------------------------------------------',\n",
130
+ " '--------------------------------------------------',\n",
131
+ " '--------------------------------------------------',\n",
132
+ " '--------------------------------------------------',\n",
133
+ " '-------------------------------------------------o',\n",
134
+ " '--------XSSSSS---------------------------------SSS',\n",
135
+ " '--------X-----------------------------------------',\n",
136
+ " '--------X-----------------------------------------',\n",
137
+ " '-------EX--E-X---------------xxxx-?-----------xxxx',\n",
138
+ " '--------XSS?SX---QQ?QQ------xx<>-x-----------xx--?',\n",
139
+ " '---------------------------xx-[]--x---------xx----',\n",
140
+ " '--------------------------xx--[]---x-------xx-----',\n",
141
+ " 'xxxxxxxxxxxxxxxxxxxxxxxxxxx---[]----xxxxxxxx------',\n",
142
+ " 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX---XXX']"
143
+ ]
144
+ },
145
+ "execution_count": 5,
146
+ "metadata": {},
147
+ "output_type": "execute_result"
148
+ }
149
+ ],
150
+ "source": [
151
+ "view_level(dataset.input_ids[:700], mario_lm.tokenizer)"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "markdown",
156
+ "id": "99be5b3a-c968-4fbd-a51a-f623003072c0",
157
+ "metadata": {},
158
+ "source": [
159
+ "#### Image"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "markdown",
164
+ "id": "d5614fc2-59bc-40ee-a92a-0cfd971e1ad3",
165
+ "metadata": {},
166
+ "source": [
167
+ "##### Previewing the first 50 columns of the dataset"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": 6,
173
+ "id": "0d6a3bf3-d050-4760-a48e-8b8655142c67",
174
+ "metadata": {},
175
+ "outputs": [
176
+ {
177
+ "name": "stderr",
178
+ "output_type": "stream",
179
+ "text": [
180
+ "/home/kokkgoblin/miniconda3/envs/py39/lib/python3.9/site-packages/Pillow-9.1.1-py3.9-linux-x86_64.egg/PIL/Image.py:992: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n",
181
+ " warnings.warn(\n"
182
+ ]
183
+ },
184
+ {
185
+ "data": {
186
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAyAAAADgCAIAAAB0EpUWAAAYPUlEQVR4nO3dT2wcVZ7A8fe6q9OOsWNv4mwDQXHsvRhWyybZbAA7M4qGS9QoWhGFA1oJgmwNITkwBywuHo1G8S0ckBgFZXGCJZgDYsgcMmmNOIyi7MYBokSMUFgIKAECJuGPkzh2/Kf/1B6abZz+U1Wv+3X5VdX3o9EInG93qitdzo/X5Sq551BWeCalsG3vOT09PT09PT19FPuYQi7Unp2enp6enp6ePpq92oAFAAAAVwxYAAAAmjFgAQAAaMaABQAAoBkDFgAAgGYMWAAAAJoxYAEAAGjGgAUAAKAZAxYAAIBmDFgAAACaMWABAABoxoAFAACgWUxKtQfQ09PT09PT09M7s5KWsG1RsEWu4H6zaCkFPT09PT09PT29cy/3HMq6VHc+wPVJ6enp6enp6ekj3qudg6X07PT09PT09PT00ew5yR0AAEAzBiwAAADNGLAAAAA0Y8ACAADQjAELAABAMwYsAAAAzRiwAAAANGPAAgAA0IwBCwAAQDMGLAAAAM0YsAAAADRjwAIAANAsJqXaA+jp6enp6enp6Z1ZSUvYtijYIldwv1m0lIKenp6enp6ent65l3sOZV2qOx/g+qT09PT09PT09BHv1c7BUnp2enp6enp6evpo9pzkDgAAoBkDFgAAgGYMWAAAAJoxYAEAAGjGgAUAAKAZAxYAAIBmDFgAAACaMWABAABoxoAFAACgGQMWAACAZgxYAAAAmjFgAQAAaBaTUu0B9PT09PT09PT0zqykJWxbFGyRK7jfLFpKQU9PT09PT09P79zLPYeyLtWdD3B9Unp6enp6enr6iPdq52ApPTs9PT09PT09fTR7TnIHAADQjAELAABAMwYsAAAAzRiwAAAANGPAAgAA0IwBCwAAQDMGLAAAAM0YsAAAADRjwAIAANCMAQsAAEAzBiwAAADNGLAAAAA0i0mp9gB6enp6enp6enpnVtISti0KtsgV3G8WLaWgp6enp6enp6d37uWeQ1mX6s4HuD4pPT09PT09PX3I+gdmXlZ4gOo5WEpbQ09PT09PT08fwV4IYSk/AgAAINr+8x8Olf753qHLQoiDBw8uDfgpQgAAAAVLpyshxORYjxBieHh46RcZsAAAAJRljl3OHLtc61cZsAAAALwqLl9ljl1O7+pJ7+opfj5YuYjFgAUAAKAZAxYAAIBmDFgAAACaMWABAABoxoAFAACgGQMWAACAZgxYAAAAmjFgAQAAePXH6/uEEMUrYN07dLl4BazKu+VYBt6wmp6enp6enp7eqL5ScbSqxUpawrZFwRa5gvvvJKWgp6enp6enp49av9Qfr+9zvdmz3HMo6/Ksd/4GRo2TIeiPPnT2yEj/4OiEe7y5Xwhhn3cvw+SZ9/7de2zgny89PT09vT/90YfOKjyg+dQGLGj3+sNnhRCuM1ZxuhJCjKWr/OpQJrRf5/0JAPBi218SlV9cxr+/OMndCIOjE0dG+mv9amm6AgAAgcAK1jIrrmAV1ZqxhjJ+bY15eH8CALwY31dlBWsZWcu9AVBTWoo08OO8Znz9f6okAACUM+3vRz4iBAAA0IwBCwAAQDPOwVpmnIPljPcnAMALzsFCQ0z4XNnPrwMA4IVpf3/xESEAAIBmDFgAAACacQ7WMiudg+VwNdGxdHQ/LOP9CQDwgnOwUIXztdqLH/EWZyzTrvPR7K9zHSwAgBem/f3IR4TLz8udcGr9cQIAAAPFpFR7AL3m3vN9BqM5Yxn350VPT09Pb2RvGrn3taxti4ItcgVh2261FElL0GvsX3/OrM+MTcP7k56enp4+iH+fqp3kLqX7i6RX6o8+dPbISP/g6IR7vLlfKJ70beDrpaenp6enj0Kvdg6W0rPTe+wHRydqXcO9xPsniY1vDz09PT09PX2DPSe5G8F5xqpvugIAAMuFyzSYwss6FgAACARWsAAAADRjwAIAANCMAQsAAEAzBiwAAADNGLAAAAA0Y8ACAADQjAELAABAM66DZQqHq4mOpcVQxs9tAQAADWEFywjO12ofyoixtG/bAgAAGsWAtfy83AmHGQsAgACJSan2AHrNvef7DBZnLOO2n56enp6enr6ClbSEbYuCLXIF95tFSyno9fZKhjJi7+NmbT89PT09PT19ZW/NZ12iMvR6e/v8xJGR/sHRCdeyuNZl2vbT09PT09PTV1I7B8t1ZKOvox8cnTgy4vJBofdPEhvfHnp6enp6evoGe05yN4LzjFXfdAUAAJYL18EyhZd1LAAAEAisYAEAAGjGgAUAAKAZAxYAAIBmDFgAAACaMWABAABoxoAFAACgGQMWAACAZlwHyxQOVxMdS4uhjJ/bAgAAGsIKlhGcr9U+lBFjad+2BQAANIoBa/l5uRMOMxYAAAESk1LtAfSae8/3GSzOWMZtPz09PT09PX0FK2kJ2xYFW+QK7jeLllLQ6+2VDGXE3sfN2n56enp6enr6yt6az7pEZeid/dONl8u+surk70v/PL39d2W/ap+fODLSPzg64frMxbUu07b/wl2/UXr+V//trNLr3XNI7QWY9n6gp49yv+XPCe/x4OjpZ97b2tTtoaf3s1f7KULXkY2+zNLppPivlTPK4Kj7jOX9k0Tn7VHtvWy/6vOb/Hrp6en19o/t6vHYy80DSv9BZebrpacv9Zzk3kRl04nDF4szR63nqW/aaJz37Vdl5usFfBYv5NbMXl01P7X69rXUrStrZq8mc/MB6j26Z+hygwEgzHv/u/ZcB6tZHAYRh3WsJm+UAtXtV2Xa6wX8t/r2tf2ZJ7/p2tgxO5maunCz7b6TD+57vzedjScD0Xv07dhPi1gnjt0xSJUWt0oB4MC0979rzwpWU6w6+ftnT8w8e2Km8peKX9eyDtQ8Qd9+IBAS+YXuyTObLr6xkGj7oO/plsXp3aee77t6Nii9kuJ0NZT56X+iYt4CnJn2/nftGbD0Kw4fhx9rs9/7a9mM8uyJGfu9vx5+rE1o+qytGYK+/UCA5Kz4xfXpV3aMv7N1ePzRw4nc7Uc+fStAfR2+HespLVnx4SCUmPb+d+75iLDpqq4DBUjQtx8wWdZqubJ2oxCiIOPTK7tutPd0zE4GqAf8ZNr737lnBUuzynWd4pJPcUHINV52Qd9+IFgSufnu786tmp+6a3E6Nf1F561LU+3dAeq9e2xXT+meqnw4iPqY9v537lnBaq6yoaRyRjFc0LcfMFw8n98weeqJMwe+7+zd/Nnbt1rXne/dGaBeSfGTwdLJWIAq097/zj0DVhP914HdDv9qvqBvP2C+rLXiq7sHUtc/uf/LzExr6t0tL364fnuA+kbww4NQZdr737lnwNJsevvvln5w9uvf/qksWDqmNH6xA+2Cvv1AsOTjyc/XbTu+aX/H3I8zyc5FtwsimNbX4bFdPZzbjvqY9v537hmw9CuNHS/84l8qf/XXv/3TS//9UeXXHa6uOZb2dTm9vu1XZc7rBZZdXlpTrang9kpOHLs8lla4wjtQxrT3f62ek9yXwfDw8PDw8NKvOF+7fCgjxtJN3iYVlduvKlivF2gSW8YWEu25mNcb9pnWN4JFLKgy7f3v2jNgNcvw8LB8eEfZF+XDO0pfLM0oXu4M4//M4X37VZn5egH/zSXazvU99fWaBwLae1Q2Sy39ccKqAVCVae9/114+82pW6RaGUqrd8jBq/T/Pvlz8h9L8UfrJu8p55eDBg0pjylhanN7Z3D8v1e3/uO03Ss8/vk/hP459eL309PTN67f9JeH9o8B7hi7z9xF9mHoraQnbFgVb5Aruj5RS0Dv3JaXhST68o+rVDQ4ePOjydBWGMmLv42Ztf93P74UPr5eenr55/VBGiIzC6pRp209P30gv9xzKulR3PsCo8TAE/dGHzh4Z6R8cnXCPN/cLIezz7mUjlG7APDh6WgjFoUkIpddr2vtz4LjCCtzg6OkjIwOB7p95b6v33sDjK+j9688193Qo044venqTe9Xv/2o/Rai0NfQe+8HRCdeZo3TeUtUBqNYZS/V93fuSvtw8oPr89vkJpderxJ8/L9X9E+he6S9gM4+voPcvfaz2KO9eUDzVysz9Q0/vZ6/0/ZPLNBjBeeaob9poxD1Dl52vAegaODPt9apS3T9B7+FRvJDrnPshG19hFbKJ/GIulphJdi5YLY30o3+74yEjvyp/krLAT6qvF1iqGceLD7x//2TAMkVx5nDNal0gStfXi0rvnrJbhpWG91JQ3/MLz6/XTN73Tzh6eLT69rX9mSe/6drYMTuZmrpws+2+kw/ue783na1xuUIvfWmichikio3zpDXyKzH6t5//XwjxgtqLq3P7gVqacbz4wPv3TwasgCl99Kb3Y8Gqym4ZNpYWJ45dLlsgrXt7QsDL/glTD1eJ/EL35Jl/nProfzf8xwd9T//rpT/vPvX8zda1H63bpqVvRNl0pYWf24/wMfl4ceXl+yfXwYK7b8d6SiM5V6yppLp/gt7DQc6KX1yffmXH+Dtbh8cfPZzI3X7k07c09nVbuoKlkW/bj1Ay9njxzuH7JwMWAGiTtVqurN0ohCjI+PTKrhvtPR2zkxr7ujVjBUv4uP0IJWOPFy34iDBglp7b1IzzsZZ+nLf0nqxlHzbXvT2Do9WbIPKyf8LUw4tEbr77u3Or5qfyMSs1/UXnrUuX7v2Fxr5uledgaeHb9iOUjD1eXHn5/smAFTBNPe+qckIqrnyWfdjcyPaEjOv+CVkPV/F8fsPkqSfOHPi+s3fzZ2/fal13vnenxr5uTVrB8m37EUrGHi9euH7/ZMCCAn64zJnq/gl6j0pZa8VXdw+krn9y/5eZmdbUu1te/HD9do193Zq0guXb9iOUjD1e6lD5/ZMBC+6WroWikur+CXoPB/l48vN1245v2t8x9+NMsnPR7QfIVfu6NWkFy7ftRygZe7x45/D9kwHLFA5X1xxL/7z86M91sMqcOHa51hXA635+j683EBz2Tyh7uMpLa6o11by+Dk1awSryYfsRYgYeL945fP9kwDKC87XLi6c0FWcOP6+DVabq5Wvr2x7vrzdAVK9+HvQelWwZW0i052Jeb1im2jeiGdOVn9uP8DH5eFFV9fsnl2lYfl7uDOPz5TrLFjwf29VTNu408omSga9Xler+CXoPj+YSbef6nvp6jdeb/Kn2jWjGdbD83H6Ej8nHiwPv3z8tA29YHa3e8333/Jw5KifxWvdaUaX6ek8b9udVpLp/gtsbd7yY3d9Y2fXmwAGNvai4JXPV+9t4uenNCxX/X7k9rhp8varPTx/uXvvxYtr3f7n3taxti4ItcgX330lKkbQEvcb+9edMXO00h2nvz6j9eZm2/6PWH91rCSGklC6pOtu2hRDPjeWMer309Cb3qt//5Z5DWYXa7HGYnp6ePkx96Rv6Sx//9MXiB3xl51GVzqyq9fFf2dlXpVUxvv/T0zevVzsHS3UxjZ6enp7en96BlrOvTHu99PSG9/wUIQBUFy/kOud+yMZXWIVsIr+YiyVmkp0LVouu3jdNujqDKmP3D7QIzfGiCwMWAFS3+va1/Zknv+na2DE7mZq6cLPtvpMP7nu/N52tcXlD1d43zbsClhJj9w+0CM3xoguXaQCA6hL5he7JM5suvrGQaPug7+mWxendp57vu3pWV+8bE6YrYfD+gRahOV50YcACgJpyVvzi+vQrO8bf2To8/ujhRO72I5++pbH3h/YrYNXNzP0DXcJxvOjCgAUANWWtlitrNwohCjI+vbLrRntPx+ykxt4fhqxgCVP3D3QJx/GiCwMWANSUyM13f3du1fzUXYvTqekvOm9dmmrv1tj7w5wVLDP3D3QJx/GiCye5A0BN8Xx+w+SpJ84c+L6zd/Nnb99qXXe+d6fG3h/mrGCZuX+gSziOF11YwQKAmrLWiq/uHkhd/+SXf/+DEOLdLS9+uH67xt4f5qxgmbl/oEs4jhddWMECgJry8eTn67Yd37S/Y+7HmWTnotsPkKv2/jBnBcvM/QNdwnG86MIKFgC4yEtrqjXl/bu/at9s5qxgFZm2f6BX0I8XXRiwAKA6W8YWEu25mNc7vKr2vjFkBcvY/QMtQnO86MKABQDVzSXazvU99fWaB9zTunrfGLKCZez+gRahOV50sQy8ATU9PT29Cf2NlV1vDhxoXu+bWitYhu/PZm8Pvd5+2Y8X03oraQnbFgVb5Aruj5RS0NPT09P70+tS616Epr1eevow9dZ81iUqQ09PT0/vT69LrRUs014vPX2YerVzsFQXt+np6enp/ekdaDkHy7TXS09veM91sABERbyQ65z7IRtfYRWyifxiLpaYSXYuWC26emMZ8lOEqkKz/wMqsseLLgxYAKJi9e1r+zNPftO1sWN2MjV14WbbfScf3Pd+bzpb4wI8qr2xap2DZbjQ7P+AiuzxoguXaQAQFYn8QvfkmU0X31hItH3Q93TL4vTuU8/3XT2rqzdWEKcrEaL9H1CRPV50YcACECE5K35xffqVHePvbB0ef/RwInf7kU/f0tibyZDrYNUhHPs/uKJ5vOjCgAUgQrJWy5W1G4UQBRmfXtl1o72nY3ZSY2+mgK5gibDs/+CK5vGiCwMWgAhJ5Oa7vzu3an7qrsXp1PQXnbcuTbV3a+zNFNwVrHDs/+CK5vGiCye5A4iQeD6/YfLUE2cOfN/Zu/mzt2+1rjvfu1Njb6bgrmCFY/8HVzSPF11YwQIQIVlrxVd3D6Suf/LLv/9BCPHulhc/XL9dY2+m4K5ghWP/B1c0jxddWMECECH5ePLzdduOb9rfMffjTLJz0e0HyFV7MwV3BSsc+z+4onm86MIKFoDIyUtrqjXl/bu/am+a4K5gFQV9/wdd1I4XXRiwAESFLWMLifZcLNGk3lgBXcEKzf4PqMgeL7owYAGIirlE27m+p75e80CTemMFdAUrNPs/oCJ7vOhiSal2C0N6enr6gPY3Vna9OXCgeb3q9vim1gpWyP68mr09UesDd7yY1ltJS9i2KNgiV3B/pJSCnp6ent6fXpda9yI07fXS04ept+azLlEZenp6enp/el1qrWCZ9nrp6cPUq52Dpbq4TU9PT0/vT+9AyzlYpr1eenrD+/LrYMULuc65H7LxFVYhm8gv5mKJmWTngtVS6yno6enpw9qHRkB/ilCVae+fqPUoUz5grb59bX/myW+6NnbMTqamLtxsu+/kg/ve701na1zQgp6enj6sfWjUOgcrZEx7/0StR5nyjwgT+YXuyTObLr6xkGj7oO/plsXp3aee77t6ttbj6enp6cPah0YUpith3vsnaj3KVDkHK2fFL65Pv7Jj/J2tw+OPHk7kbj/y6VsOT0FPT08f1j4cAnodrDqY9v6JWo+lqgxYWavlytqNQoiCjE+v7LrR3tMxO+nwFPT09PRh7cMhIitYwrz3T9R6LFVlwErk5ru/O7dqfuquxenU9Bedty5NtXc7PAU9PT19WPtwiM4Klmnvn6j1WKr8JHchRDyf3zB56okzB77v7N382du3Wted793p8BT09PT0Ye3DITorWKa9f6LWY6kqA1bWWvHV3QOp65/c/2VmpjX17pYXP1y/3eEp6Onp6cPah0NEfopQmPf+iVqPpaoMWPl48vN1245v2t8x9+NMsnPR7Qcy6enp6cPah0NEpith3vsnaj2WqjJgFeWlNdWa8v5E9PT09GHtgy46K1hFpr1/otajqPwkd1vGFhLtuVjC4+Pp6enpw9qHRkSmK9PeP1HrUaZ8wJpLtJ3re+rrNQ94fDw9PT19WPvQiMhPEZr2/olajzLymVezSrcwlFLtlof09PT09PX14/t+Wjx46eOfvlIcksrWokqrU7VGqLIVrBf+/29Mvv/T0zevt5KWsG1RsEWu4P5IKQU9PT09vT+9LrXOwTLt9dLTh6mX9vmJIyP9g6MTLq0QcnO/EIKenp6enp6efs+hrGv580Oav7x09KGzRu2fmBBicHTiyEi/l7qInp6enp6ent47pWmp7t6o/RPz8pjKvUlPT09PT09Pbxpz9s/P18HyMpd5/z3o6enp6enpo9CbxpD9U+VmzwAAAGgEAxYAAIBmDFgAAACaMWABAABoxoAFAACgGQMWAACAZgxYAAAAmv18HSyHq4eNpcVQpvyL9PT09PT09PSmMWT/xFxrIcRQRoylvT47PT09PT09fUR605izf2KudeVj6Onp6enp6elNY9T+ka7pUqprg/T09PT09PSh7E/vzCrdkllKtVs4q/bj+xLeYx/2j9qABQAAIITY+1rWtkXBFrmC+yQkpUhaoqn9688pDFg++D9omqPgoC0DtAAAAABJRU5ErkJggg==\n",
187
+ "text/plain": [
188
+ "<PIL.Image.Image image mode=RGB size=800x224>"
189
+ ]
190
+ },
191
+ "execution_count": 6,
192
+ "metadata": {},
193
+ "output_type": "execute_result"
194
+ }
195
+ ],
196
+ "source": [
197
+ "img = convert_level_to_png(dataset.input_ids[:700], TILE_DIR, mario_lm.tokenizer)[0]\n",
198
+ "img"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "markdown",
203
+ "id": "28a7e683-a9a2-4321-b21a-807daf7aa744",
204
+ "metadata": {},
205
+ "source": [
206
+ "#### Set device"
207
+ ]
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "execution_count": 7,
212
+ "id": "7a6f684a-63a9-4a34-9a57-fd6aa84375a0",
213
+ "metadata": {},
214
+ "outputs": [],
215
+ "source": [
216
+ "device = torch.device('cuda')\n",
217
+ "mario_lm = mario_lm.to(device)"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "markdown",
222
+ "id": "3869772f-e3a6-43d4-94ee-40364028bea8",
223
+ "metadata": {},
224
+ "source": [
225
+ "## Generating Levels"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": 45,
231
+ "id": "1e7589f2-2b48-4174-9fc7-7e7de7ff3615",
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "prompts = [\"many pipes, many enemies, some blocks, high elevation\"]"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "markdown",
240
+ "id": "aa0a437f-4123-44b2-b08f-985f60165fb2",
241
+ "metadata": {},
242
+ "source": [
243
+ "##### We generate 1399 predictions for an even 1400 output (including the input seed which is just a single block). Mario Levels have height of 14, so we generate 100 columns."
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": 46,
249
+ "id": "766362fb-8b90-43a4-b405-17fed2342d31",
250
+ "metadata": {
251
+ "scrolled": true,
252
+ "tags": []
253
+ },
254
+ "outputs": [
255
+ {
256
+ "name": "stderr",
257
+ "output_type": "stream",
258
+ "text": [
259
+ "shape: torch.Size([1, 685]), torch.Size([1, 1400]) first: \n"
260
+ ]
261
+ }
262
+ ],
263
+ "source": [
264
+ "generated_level = mario_lm.sample(\n",
265
+ " prompts=prompts,\n",
266
+ " num_steps=1399,\n",
267
+ " temperature=2.0,\n",
268
+ " use_tqdm=True\n",
269
+ ")"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "code",
274
+ "execution_count": 47,
275
+ "id": "777f94cf-a765-4f7a-a7b4-223c29680e17",
276
+ "metadata": {},
277
+ "outputs": [
278
+ {
279
+ "data": {
280
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAABkAAAADgCAIAAADUoj0kAAAwcElEQVR4nO3dXYxUVaLo8bW7qqjutpvuAJ5SMbT0U8vJmKYvx4+GuSHj5KRtDg9jEOOZHEXo+EHfhJmEHpMJxkzo3AfxwROPGGYa7MTxgVHHBw6t4WFCmMuHIgw3BqPoxQ+0BZUW+4P+qKre92FjWVRX79qraq9da+39/4UY6P7XZlXXrtqL5a5d1qbdaeGZZQnb9p7T09PT09PT09PT09PT09PT09NX2tdI5EJu6/T09PT09PT09PT09PT09PT09JX3cgtYAAAAAAAAQMBYwAIAAAAAAIDWWMACAAAAAACA1ljAAgAAAAAAgNZYwAIAAAAAAIDWWMACAAAAAACA1ljAAgAAAAAAgNZYwAIAAAAAAIDWWMACAAAAAACA1ljAAgAAAAAAgNZYwAIAAAAAAIDWWMACAAAAAACA1mosS+4G9PT09PT09PT09PT09PT09PT0QfbxZFzYtpi1RWZW2HbprdPT09PT09PT09PT09PT09PT0wfZW5t2p0tU19+g5Ebp6enp6enp6enp6enp6enp6el97OWugSW1dXp6enp6enp6enp6enp6enp6+sp7LuIOAAAAAAAArbGABQAAAAAAAK2xgAUAAAAAAACtsYAFAAAAAAAArbGABQAAAAAAAK2xgAUAAAAAAACtsYAFAAAAAAAArbGABQAAAAAAAK2xgAUAAAAAAACtsYAFAAAAAAAArbGABQAAAAAAAK2xgAUAAAAAAACt1ViW3A3o6enp6enp6enp6enp6enp6emD7OPJuLBtMWuLzKyw7dJbp6enp6enp6enp6enp6enp6enD7K3Nu1Ol6iuv0HJjdLT09PT09PT09PT09PT09PT0/vYy10DS2rr9PT09PT09PT09PT09PT09PT0lfdcxB0AAAAAAABaYwELAAAAAAAAWmMBCwAAAAAAAFpjAQsAAAAAAABaYwELAAAAAAAAWmMBCwAAAAAAAFpjAQsAAAAAAABaYwELAAAAAAAAWmMBCwAAAAAAAFpjAQsAAAAAAABaYwELAAAAAAAAWmMBCwAAAAAAAFqrsSy5G9DT09PT09PT09PT09PT09PT0wfZx5NxYdti1haZWWHbpbdOT09PT09PT09PT09PT09PT08fZG9t2p0uUV1/g5Ibpaenp6enp6enp6enp6enp6en97GXuwaW1Nbp6enp6enp6enp6enp6enp6ekr7+NytwAAABEQm800T36Xji2Iz6YT2ZlMTWI82Twdr61WDwAAgIhjAQsAABRadPVS79BDXy1pb5oYTo2c/aHh1sN3bH2ntTsdS1alBwAAQMTJvYUQAABEQSI73TJ8fOW5V6YTDe+2PVI7M7rhyLa2iyer1QMAACDiWMACAABFZOKxc8u6X+gafOPOvsF79yQyV+/5aH8VewAAAEQZC1gAAKCIdLz2wo3tQohZKzZat+RK4/KmieEq9gAAAIgyFrAAAEARicxUyzenFk6N3DAzmhr9rHns/EhjSxV7AAAARBkXcQcAAEXEstnbho88cHznt82tHR+/Nla/9HTr+ir2AAAAiDIWsAAAQBHp+IIvblqd+v7D2z8fGq9PHVr11Jlla6vYAwAAIMpYwAIAAEVkY8lPlq45sLK3afLyeLJ5Jpasbg8AAIAoYwELAADMK2vFR+pT+vQAAACIJi7iDgAACtlWzXSiMVOT0KQHAABAxLGABQAACk0mGk61Pfzl4hWa9AAAAIg469GX0rYtcwNL0NPT09PT09PT09PT09PT09PTB9bHk3Fh22LWFpnZ0re0LEFPT09PT09PT09PT09PT09PTx9kb23anS5RXX+Dkhulp6enp6enp6enp6enp6enp6f3sZe7BpbU1unp6enp6enp6enp6enp6enp6Svv43K3AAAYIjabaZ78Lh1bEJ9NJ7IzmZrEeLJ5Ol5rSg93UXu8dBsPAAAAAsYCFgCE06Krl3qHHvpqSXvTxHBq5OwPDbcevmPrO63d6VjSiB7uovZ46TYeAAAABEzuLYQAAFMkstMtw8dXnntlOtHwbtsjtTOjG45sa7t40pQe7qL2eOk2HgAAAASMBSwACK1MPHZuWfcLXYNv3Nk3eO+eRObqPR/tN6iHu6g9XrqNBwAAAEFiAQsAQisdr71wY7sQYtaKjdYtudK4vGli2KAe7qL2eOk2HgAAAASJBSwACK1EZqrlm1MLp0ZumBlNjX7WPHZ+pLHFoB7uovZ46TYeAAAABImLuANAaMWy2duGjzxwfOe3za0dH782Vr/0dOt6g3q4i9rjpdt4AAAAECQWsAAgtNLxBV/ctDr1/Ye3fz40Xp86tOqpM8vWGtTDXdQeL93GAwAAgCCxgAUAoZWNJT9ZuubAyt6mycvjyeaZWNKsHu6i9njpNh4AAAAEiQUsAAi5rBUfqU+Z28Nd1B4v3cYDAACAYHARdwAIJ9uqmU40ZmoShvZwF7XHS7fxAAAAIGAsYAFAOE0mGk61Pfzl4hWG9nAXtcdLt/EAAAAgYNajL6VtW+YGlqCnp6enp6enp6enp6enp6enpw+sjyfjwrbFrC0ys6VvaVmCnp6enp6enp6enp6enp6enp4+yN7atDtdorr+BiU3Sq9Vv2L8+YIvLjz8h9zvR9c+U/Ddszf8JlLb1+3xoqenp/er33fXyb07Orf0Hysdd3QKIezTpct8j574F6nx6Pbzoaenp6enp5/blzF/kOpZf6CvpJf7FEKprdNr2Oev/jh/nLsGFOXt09PT04ep39Jfek7pzCaFEHt3dM79bs+QGOgucqueIbkJqJ4/H3p6enp6evq5vdT8oYxedjz09Lmei7hHSMHqj8sXo7l9GCc2m1k8cXHh1Miiq5dSYxcWT1xMZqYM6mXpNh7VTB+/LEWPrzOnnG8j5c0mjRC1/QcAAB/Jzh8iO99AwOTOwIK5XBZ6KjyPKRzbh4kWXb3UO/TQV0vamyaGUyNnf2i49fAdW99p7U7Hkkb0ut1f3Zg+flnqHl/3OWVOz5Dc1zUXtf0HAAB/eZw/lN0DZeAMrEhYePgPjx8cf/zg+NxvOV+v8Dwm07cPQyWy0y3Dx1eee2U60fBu2yO1M6Mbjmxru3jSlF6WbuNRzfTxy6r64zvQfe1X/u/zfxV8XXNR238AAABCjwWs8HMWd/asa7BPvF2wBvT4wXH7xNt71jWICt6LZ/r2YbRMPHZuWfcLXYNv3Nk3eO+eRObqPR/tN6iXpdt4VDN9/LKi9viqxs8HAAAgTHgLYeQUPY+J7cNQ6XjthRvbhRCzVmy0bsmVxuVNE8MG9bJ0G49qpo9fVtQeX9X4+QAAAIQJC1ghN/e8JPvE23/cuUEI8djTr1t3dxXG634Tqe3DdInMVMs3pxZOjWRr4qnRz5rHzp+/5ecG9bJ0G49qpo9fVnUf3/xrXXm5Htamfyv7rwpI1PYfAACAcGMBK1rsE2+7/JHtwzixbPa24SMPHN/5bXNrx8evjdUvPd263qBelm7jUc308cuq7uNb9LJWPUPzfl1/Udt/AAAAwo0FrAhxTlya749sHyZKxxd8cdPq1Pcf3v750Hh96tCqp84sW2tQL0u38ahm+vhlRe3xVY2fDwAAQJiwgBVyo2ufyX8X3mNPv14Q5C8Dja59Jmrbh+myseQnS9ccWNnbNHl5PNk8E0ua1cvSbTyqmT5+WVF7fFXj5wMAABAmLGCFX25ZZ/vPfzb3u489/fpzf38/yttHCGSt+Eh9ytxelm7jUc308cvy/fG1Ojrn+9ZA909vBvRy3SsTRW3/AQDAFx7nD2X3QBlYwILo6+sTQuzatYvtwyy2VTOdaMzUJAztZek2HtVMH78sRY+vy2xS/HiJK2dOmbvWlct1r/K//n88DrRKorb/AADgI+/zh/J6oDw11R4AAtLX11fwmX1CCOvurtwXnWWgyG4fJppMNJxqe/jLxSsM7WXpNh7VTB+/LBWPr/ts0jHfcpXporb/AADgF9n5Q5TnGwhY3LKEbUvcgN6svvDmd3flPrlv7npQBLdPb3R/pW7Jn1fvNLcvSfPxqO5NH79s7//j62E26ShvThmyn6fq8dDT09PT05vRS84fZPujut1ferP6J/6Utm0xa4vMbOlbWpZIxgW9Qf2K8edzf8ydo+SsARUsADlvwft/zb+J1PZ1e7zo6enp/epfflLtu+eYP9DT09PT04evZ/5Ar3NvbdqdLlFdf4OSG6VX2qt+QTEd+zN9mHrVz3eeL2b1+Qv6jvwPaZ37Mazb7129d0fnlv5jpTfe0SnYH+jp6enp6ektse+uk1LzB/v0MeYb9IH1chdxl9o6vaL+uQ/kbuXd9hVh2L53ej6+9PT5Pc8X+vn6/NUr549z17C29JeeU3o/8999PPT09PT09PTh6GXnD8w36APr+RRC8/T/7bo/7vhFiSBq24e5YrOZ5snv0rEF8dl0IjuTqUmMJ5un47V+9brxMn6eL97ptj8o3Z8LVq9yX5RdwypvNhlKuu0/AABUkez8gfkGgsEClnly/4J1+Yer07j/y3bHL0T/3376rxBieyi2D3Mtunqpd+ihr5a0N00Mp0bO/tBw6+E7tr7T2p2OJX3pdeNl/DxfvNNtf1C3Pxddvcp9a741rArvTujptv8AAFBdsvMH5hsIQE21B4CqKfjXLNtH1SWy0y3Dx1eee2U60fBu2yO1M6Mbjmxru3jSr143QY4/Cs8X3fYHRfvzwsN/ePzg+OMHx+duwfm6y/IWXOi2/wAAAKAAC1jRlX9GBtuHJjLx2Lll3S90Db5xZ9/gvXsSmav3fLTfx143gY0/Is8X3fYH3/dnZ3Fqz7oG+8TbBWtYjx8ct0+8vWddg3A9RQsudNt/AAAAkI+3EEaX6WdIReGMkghKx2sv3NguhJi1YqN1S640Lm+aGPax101g44/I80W3/UH1/lz0PCyUTbf9BwAAAPk4Ayu6TD9DKiJnlERNIjPV8s2phVMjN8yMpkY/ax47P9LY4mOvm8DGH5Hni277g7/789zzqpxTrpwTsgq+xUlYZdBt/wEAAEA+zsCKLtPPkIrIGSVRE8tmbxs+8sDxnd82t3Z8/NpY/dLTret97HUT2Pgj8nzRbX9Quj8XLFrNXcOCLN32HwAAAORjASu65n4qGdtH1aXjC764aXXq+w9v/3xovD51aNVTZ5at9bHXTWDjj8jzRbf9Qd3+/MedG1z+iPLotv8AAAAgHwtY0WX6GVJR+Nd4BGVjyU+Wrjmwsrdp8vJ4snmm1AfYy/a6CWz8EXm+6LY/+Ls/j659Jv+NgY89/XrBzfOXsUbXPlPx8CNHt/0HAAAA+VjAii7Tz5CKyBkl0ZS14iP1KXW9bgIYf6SeL7rtDz7uz7llqe0//9nc7z729OvP/f39uV+3Ojrn+7sGukXPkPehRYJu+w8AAFUhO39gvoEAcBH36DL9DKno/Gs8OmyrZjrRmKlJKOp1E+T4o/B80W1/qMr+3NfX19fXl/8Vl9mkEKJnSAx0V/IXhodu+w8AAFUkO39gvoFgsIAVXaZ/SmBEPlUtUiYTDafaHv5y8QpFvW6CHH8Uni+67Q/q9ue+vj7r7q6CL1p3d+W+mFvDcp9NOphTOnTbfwAAqBbZ+QPzDQTGevSltG3L3MAS9FXsB7fyP4fdsD/Th6lX/Xzn+WJW/88Tzzu/ya1P5T55cO561q5duwpOxXI30C2Ormd/oKenp6enj3ovNf+UfW8g8w36CvuaZFwk4yIRE5blaev01e1t27alHmHP7Dzmbl+3x4uevpKe5wt9fp+za9eua1+cs25VEHjXM6Td/aWnp6enp6cPvpcie2Ur5hv0FfbxqXTpLh99dXvrxwf2uQ+ufcV5Q1DBdW1yV7qZ7+1CBVfD2b7ipy0bvf1Nu+V+oLo9vvT0+VQ/33m+VLd/6X+c3Lujc0v/sZKlc2a+ffpYrt9++rpb2df/cW7vZfu6/Xzo6enp6enpg+9l5w/MN+iD7OWugSV7KgB9dXsXqq+GY8T2dXu86Okr6V3wfNG239J/bO+OEpeNyL+uhOreOz1/nvT09PT09PSV98w36LXtuYh7RKn+PDLTtw9fxGYziycuLpwaWXT1UmrswuKJi8nMlI89POL5UpTq/dNj7z7nmzvbU93DI16vAABeaDLfkMV8A3qKV3sAqI7cGRmK/k1r+vbhi0VXL/UOPfTVkvamieHUyNkfGm49fMfWd1q707GkLz084vlSlOr903vv5f9bBtnDC16vAABe6DPfkMV8AxriDKyIMv0MKf41boREdrpl+PjKc69MJxrebXukdmZ0w5FtbRdP+tXDI54vRaneP9mfw43HFwDgBfMNwEcsYEWUEdeoquL24ZdMPHZuWfcLXYNv3Nk3eO+eRObqPR/t97GHFzxf5qN6/2R/DjceXwCAF8w3AL+wgBVRpp8hxRklpkjHay/c2C6EmLVio3VLrjQub5oY9rGHFzxf5qN6/2R/DjceXwCAF8w3AL+wgBVRpp8hxRklpkhkplq+ObVwauSGmdHU6GfNY+dHGlt87OEFz5f5qN4/2Z/DjccXAOAF8w3AL1zEPaJMP0OKM0pMEctmbxs+8sDxnd82t3Z8/NpY/dLTret97OEFz5f5qN4/2Z/DjccXAOAF8w3AL5yBFVGmnyHFGSWmSMcXfHHT6tT3H/7P//tfQohDq546s2ytjz284PkyH9X7J/tzuPH4AgC8YL4B+IUzsCLK9DOkOKPEFNlY8pOlaw6s7G2avDyebJ4p9YG+sj284PkyH9X7J/tzuPH4AgC8YL4B+IUFrIjKnZGh6N+0pm8f/spa8ZH6lLoe7ni+uFO9f5bsrY7O+b410C16hoLuIYXXKwCAF1Wfb8hivgEN8RbCiDL9DCn+NW4E26qZTjRmahKKenjE86Uo1funx95ltieE6BkSA92B9vCI1ysAgBeazDdkMd+AnljAiijTr1HFNX2MMJloONX28JeLVyjq4RHPl6JU759eevfZniN/zqe6h3e8XgEAvNBhviGL+Qa0FbcsYdsSN6Cvbu8X08+Qmm/7uj1eEe+v1C358+qd6nrZ8RjX+4XnS9Fe9f5Zuvcw23M4cz7V/VG9Hy/del6v6Onp6em99NWfb8iOn/kGvcZ9TTIuknGRiAnL8rR1+ur2fjH9DKn5tq/b40VPX0nvF54vevZSZK8cUUav28+Hnp6enp6ePvheCvMN+oB7a9PudOkw7wZaLb+V0a8+IPH24C39R/fuWC3VP3riTqnx6DZ+03vVP396+iD7l5+89nx/7oNrX3QWoQrOpcqdXTXfknHBGVjbfzzHPGqv/6r7fXed3Lujc0v/sdJxR6cQwj59TLe+ZJnv0RP/4j3W8PGip6enp6fXoZedP6iev6mezzD/pK+kl/sUQqmta9uvu3+5x97qWD3QLddLPSE1HL/pfQA/f3p6PXsXvnwKoW73V89+S3/pOVz+mfZa9Xt3FHkLwHxXrOgZkpuA6vl40dPT09PT69DLHt91G49u46cPcR/Ri7jf3POpVCDbq6Z6/Kb3UCQ2m1k8cXHh1Miiq5dSYxcWT1xMZqZ87KEIn0LoC4/7szOHm28jc2dvuvVQhNdDAAgHRfPhso/Xio4vzDegJ7kzsELj64FrJ/Uc/Ot1Cx+5k31yQXm9aqrHb3oPRRZdvdQ79NBXS9qbJoZTI2d/aLj18B1b32ntTseSvvRQxJczsOB9f3afw82lST/fZSxkL2+Bong9BIBwUDcflj2+lzce71TPT4AyRPQMLIezGtIzdO2XmLM+UmGvmurxm97Dd4nsdMvw8ZXnXplONLzb9kjtzOiGI9vaLp70q4cirF75IvT780D3tV/5v8//VfB1SAn9/gMAEaHbfJjjCyIl0gtYOV8PLM+dwuPlzWiyvWqqx296Dx9l4rFzy7pf6Bp8486+wXv3JDJX7/lov489VFD9qaDRwf6MSrD/AEA46DYf5viC6GABC4CEdLz2wo3tQohZKzZat+RK4/KmiWEfe6jAGVh+YX9GJdh/ACAcdJsPc3xBdET0GliOdfcvz52/4+XNaLK9aqrHb3oPFRKZqZZvTi2cGsnWxFOjnzWPnT9/y8997KEC18DyS7j35/xrXXm5Htamf1M7nvAJ9/4DANGh23yY4wuiI9ILWOLHi3/nLq7ke6+a6vGb3sN3sWz2tuEjDxzf+W1za8fHr43VLz3dut7HHiqweuWXcO/PRS9r1TM079chK9z7DwBEh27zYY4viI6oL2AVkP0wO90+/E71+E3vUbl0fMEXN61Off/h7Z8PjdenDq166syytT72UIEzsPzC/oxKsP8AQDjoNh/m+ILoYAFLiOvfm6aiV031+E3v4aNsLPnJ0jUHVvY2TV4eTzbPlPqAXtkeKrB65Rf2Z1SC/QcAwkG3+TDHF0QHC1hCCHHwr58OdIt193s9nUe2V031+E3v4busFR+pT6nr4S/OwPJXyf3Z6uic71sD3UXefKdJ7+W6V6gcr4cAEA6+z4dlj+8Vjqck1fMToAx8CuFPZE/q0e0kINXjN71H5WyrZjrRmKlJKOqhCKtXvvC4P7vM3kSxS0rp0w90X/uV//v8XwVfhxReDwEgHBTNh2WP72WPxyPV8xOgPBFdwCpY+1h3//KCJeGCQLZXTfX4Te+hyGSi4VTbw18uXqGohyK5M7BQCS/7s/vszZE/h9Othzq8HgJAOKiYD1dyvFZxfGG+AW3FLUvYtsQNTO8dcy/+XfAV9z+6BKaP3/Ret/0tZP2VuiV/Xr1TXS87ntD3fpnvDCzd7q/mfen92cPszeHM4XTrZWn+eOnW83pIT09PH47e//mw5PH6qOrji+L5xlG9H1963fsn/pS2bTFri8xs6VtalkjGhdH9y0+qPXtf9c9T9fhNF7X9mT7cfe75/twH177onEVVsBSVW5xyOccqfw1r+4//i47ni7991F6f2X/o6enp6ekr72XnD6b/e5P5A30lvbVpd7pEdf0NSm6Unp6ent6Xft9dJ/fu6NzSf6x03NEphLBPH5Pqef1371X//HXr2R/o6dX1Azfvd37/qzO/dinfbH/V+c2W4QeVjoeenl5db/r8zfTx04e7l7sGltTW6enp6ekr7Lf0H9u7o8SJ2flnbsv2suOJWq/6569b752ejxc9vc593cbNdRs3v/2/k2+2vzrfr1wWwHjo6enV9aYfr00fP32I+4hexB2AECI2m1k8cXHh1Miiq5dSYxcWT1xMZqZ87OEL9znB3NmAbA93qn/+uvXQBK+34TP5l32Tf9knflyiKvorPwMQDEXzYdOP16aPH2EVr/YAAFTNoquXeoce+mpJe9PEcGrk7A8Ntx6+Y+s7rd3pWNKXHn7x8v+1KunhTvXPX7ceOuD1NsRYnwK0om4+bPrx2vTxI5Q4AwuIrkR2umX4+Mpzr0wnGt5te6R2ZnTDkW1tF0/61QMAysPrbfh4fG9gGW8hBFAJ5sOAQTgDC4i0TDx2bln3C12DNXb2H8u7fvvmL+/5aP/7S9f41QMAysPrbVi5rE9xchZQFcyHAVNwBhYQael47YUb24UQs1ZstG7JlcblTRPDPvYAgPLwegsAwWA+DJiCBSwg0hKZqZZvTi2cGrlhZjQ1+lnz2PmRxhYfewBAeXi9BYBgMB8GTMFbCIFIi2Wztw0feeD4zm+bWzs+fm2sfunp1vU+9gCA8vB6G2757xbkoldAdTEfBkzBAhYQaen4gi9uWp36/sPbPx8ar08dWvXUmWVrfewBAOXh9TbEJv+yL3/R6q1Vdfe9N1nF8QARx3wYMAULWECkZWPJT5auObCyt2ny8niyeabUB7TL9gCA8vB6G1a51SvnJKy6jZvve2+SNSygipgPA6ZgAQuAyFrxkfqUuh4Vsjo65/vWQLfoGaq0hzvVP3/demiF19uwyi1jsXQFaML3+bDpx2vTx49Q4iLuQHTZVs10ojFTk1DUwxcuswEhRM+QGOiuqIc71T9/3XpogtdbAAiGovmw6cdr08ePsGIBC4iuyUTDqbaHv1y8QlGPyrnPBhz5cwLZHu5U//x166EPXm9Dr27j5rdW1eWffpV/WXcAgVExHzb9eG36+BFi1qMvpW1b5gaWoKenp6cPoB/cKnHyhey52QPd4uh6Xv/dqP7569azP9DTq+v33rLf+U3dxs0FV3B3OCtZuTWsnq8f1Gr89PT03pk+fzN9/PQh75/4U9q2xawtMrOlb2lZIhkX9PT09PQB9C8/qfbdQ7z+u/eqf/66YX+gp1fXD9y831m06vr99JvtrxZdwPrPf72S+9bmFye0Gj89Pb0+8wfVx2vTx08f7t7atDtdorr+BiU3Sk9PT09Pr2G/766Te3d0buk/Vjru6BRC2KePRapnPkBPr66/fOna5yYVXb1yOGtYzu8X/VNG6XhWH5D4B+qW/qOPnrhTavsDN1874+xXZ37tUr7Z/uq1v2L4Qant6/b4qu5XjD9f8MWFh/+Q+/3o2mcKvnv2ht8oHU/U7q9u8wfVx2vV95f5Bn0lvdynEEptnZ6enp6eXqt+S3/pOVb+dRyi1nun5+NLTx+OXlZ541l3/3KPvdWxWuofnM72f1yn2+x+bS8n+/fnJRawdHu8gu/zV3OcP85d0wlyPKp7He5v1I7Xpo+fPsR94UXcY7OZxRMXF06NLLp6KTV2YfHExWRmSm6rAHwi+3xU3QMh4Myx5vvu3NlV1HoYitdzI7icfiWEuO+9yW2HmgMcjri559MKAxeTf9nnLF3Vbdw836/8DB4VrOa4fDEcyri/iubDoTlea3J/gfIUnoG16Oql3qGHvlrS3jQxnBo5+0PDrYfv2PpOa3c6lqzK+IAok30+qu6BcHCfY9HDRLyeowxfD1w7CevgX69bqMqdnJULKsH6lI9cFm4qPC9JT+XdX3Xz4XAcr/W5v0AZCs/ASmSnW4aPrzz3ynSi4d22R2pnRjcc2dZ28WRVBgdEnOzzUXUPANATr+f6y51+9daquqK/RDVOwhI/rl71DF37JeasZ5Und46VLxmEEAsP/+Hxg+OPHxyf+y3n6yE7D6vs+8t82F3U7i9Cpsg1sDLx2Lll3S90DdbY2X8s7/rtm7+856P97y9dE/zgAMg+H1X3AAA98XpuBGehSgix9ncv5r54+Nle51v3vTdZnWEJIX4838pZvbq551NfTr8SP10MqwhOzvLOWazZs67hsadft+7u2rOuIfetxw+O2yfe/uPODSJE52FVeH+ZD7uL2v1FmBRZwErHay/c2C6EmLVio3VLrjQub5oYDnpcAIQQ8s9H1T0AQE+8nuus4OpX+atXzh+dNSzAo6LnJYWY7P1lPuwuavcXYVL4FkIhRCIz1fLNqYVTIzfMjKZGP2seOz/S2BL8yAAI+eej6h4AoCdez82SW7EqWLoK/l2E6+5f7rxzUPj05kH4bu575ewTb+9Z17BnXYN94u2SsXEqv7/Mh91F7f4iTIqcgRXLZm8bPvLA8Z3fNrd2fPzaWP3S063rgx8ZACH/fFTdAwD0xOu5Keaeb3X42d7qnoSV/+bB3GKWv/LfLchFrypUsIgzd00nZMq4v8yH3UXt/iJMir6FcMEXN61Off/h7Z8PjdenDq166syytYEPDIAQ8s9H1T0AQE+8nsMvfl39KmfyL/vyF62qfrUvoznXfprvj+FT3v1lPuwuavcXYVJkASsbS36ydM2Blb1Nk5fHk80zfAAzUD2yz0fVPQBAT7yeo0Lr7l9+c4//byHMrV45J2HVbdx833uTrGFJGV37TP4b5R57+vWCIH9ZJwQXca/8/jIfdhe1+4swKbKA5cha8ZH6VJBDATAf2eej6h4wmtXROd+3BrqLvH0maj2Mxuu5EfIv4l5wQfcqOvjXTwe6xbr7fT79ypFbxmLpqjy5ZZrtP//Z3O8+9vTrz/39/WBHpJYv99f3+XDIjtdVv79AGQov4m5bNdOJxkxNoiqjAZBP9vmougdCwGV2JYToGRID3ZHuYShezw1y+Nle53JXc3+jAxUnYSEAfX19fX191R5FcFzur6L5cGiO15rcX6A8hQtYk4mGU20Pf7l4RVVGAyCf7PNRdQ+Yzn125cifY0Wth7l4PTfR2t+9WN3TrwrWqvI/jrBoUIm6jZvfWlWXf/pV/mXd4VFfX591d1fBF627u3JfDNkaVnn3V8V8OEzHax3uL1A269GX0rYtcwNL0NPT09PTG9cPbpU4OUX2XPcQ9EfXMx+gp1fV771lf+4q5m+tqiva5FZ23lpV9+f/+E7peNb8d8L7WwVv7vlU9t8Le2/Z7/ymbuPmgiu4O5yVrNwaVs/XD2r1eOnW//PE885vcus1uU/im7u+s2vXrg8afqPV+E2/v7rNH1Qfr1XfX+Yb9JX08WRc2LaYtUVmtvQtLUvQ09PT09Ob2EuRvVJDCPonfqXX40VPH6Y+n5crQKkeT8+QEEOflhxG2eMRQjiLVl2/nxbi1aKZ8y0n2/zig1o9Xrr1Obt27XLWdKy7u3JrOvl27dol1O8/Ubu/UkJwvJYdj2zPfIO+kt7atDtdorr+BiU3Sk9PH1j/8pNqL3fC6wO9zv2+u07u3dG5pf9Y6bijUwhhnz5G79LzfKenV9cP3HztjKRfnfm1S/lm+7W1ni3DDyodj+r+8qVrnxP1Zvurc0+/cry1qu4///WK8/tF/5RROh56+vze9PkDx2v6KPfzfgphUVJbp6enD6B/7gO5W3m3XfJSKnr+fOjD3W/pLz3ny78uA717752e+wM9vc79j+s4m92v/eRk//68xAKWnvdXHd3GT29ir9vxl+M1Pb3HXu4MLMRmM82T36VjC+Kz6UR2JlOTGE82T8drq9XDLL7vD4NbE83/dd1NdvyicCP9fytztFf+l9z/4QEC9vLdJ53fuMz58md79ulj9C49z/dIYb4RsPxrQrlkubUtqTOwNDTyTVy4nn7lyJ2EJXUGFqLG9/mz6fOHgI/XHC/MEvr1CrkzsLDo6qXeoYe+WtLeNDGcGjn7Q8Oth+/Y+k5rdzqWrEoPs6jYH3IrVi4LVU7jvpK14xei/28//VcIsV3uzgFV4/x/S3q/ekQB840q4tP3ACnq/j2l2/FXz+M1xwuzhH69oiaYvyY0EtnpluHjK8+9Mp1oeLftkdqZ0Q1HtrVdPFmtHmbReX8oWL0CAIQY843g1W3c7H46klSmv9zpV2+tqiv6Swhx33uT2w41V3mg0J7O8+co4OdpltCvV3AGlrRMPHZuWfcLXYM1dvYfy7t+++Yv7/lo//tL11Srh1m03R/mnoEFAAgx5hvV4rI+Fb6Ts5yFKiHE2t+9mPvi4Wd7nW95+UBGQGg8f44Ifp5mCfd6BWdgSUvHay/c2C6EmLVio3VLrjQub5oYrmIPs2i7P7B6BQCRwnwD6hRc/Sp/9WruH4GStJ0/RwQ/T7OEe72CBSxpicxUyzenFk6N3DAzmhr9rHns/EhjSxV7mEXb/SH/DCwAQOgx30CQnLOu8n/j4F2E8ELb+XNE8PM0S7jXK3gLobRYNnvb8JEHju/8trm14+PXxuqXnm5dX8UeZtF2f+AMLACIFOYb1ZX/bsFwXPRqPmt/92LBotXhZ3vnfhFwoe38OSL4eZol3OsVLGBJS8cXfHHT6tT3H97++dB4ferQqqfOLFtbxR5m0XZ/4BpYABApzDeqaPIv+/IXrbgaFOBO2/lzRPDzNEu41ytYwJKWjSU/WbrmwMrepsnL48nmmVIfGKm6h1m03R9YvQKASGG+US251SvnJKy6jZvve2+SNSzAhbbz54jg52mWcK9XsIBVpqwVH6lP6dPDLBruD5yBBXNZHZ3zfWugW/QM0cv1iBTmG9WSW8aKwtJV/lXbuYI7yub7/Fm346/mx2uOF2YJ63oFF3GXY1s104nGTE1Ckx5m0Xl/YPUKhnKZ7QkheobEQDe9RI+IYL6BwBx+tte53NXc3wBeKJo/63b81fZ4zfHCLKFfr2ABS85kouFU28NfLl6hSQ+z6Lw/8CmEMJH7bM+RP+ejd+8RHcw3qq5u4+a3VtXln36Vf1n3UFr7uxc5/QplUDF/1u34q/PxmuOFWUK/XmE9+lLatmVuYAl6enpN+sGtahe/eX2g17mX2v9lz72PYH90Pc93enpV/d5b9ju/qdu4ueAK7g5nJSu3htXz9YNajb+M+5u7j2+tqiva5Fbu3lpV9+f/+E6r8dOHuzd9/sDxmj7KfTwZF7YtZm2RmS19S8sS9PT0+vS2bQshLMsqkcpztvzkQEar+0tPn99Lkb1yRAT7J36l1+NLTx+mXgjhLOh0/X5aiFeLZs63nGzziw9qNf5KXp+9XOFLt/HTh7uXwvGanl6r3tq0O12iuv4GJTdKT08fWP/yk9f+D9JzH1z7ovMGwILrWOWubDXf2wMLrn61/cdzQnl9oA+y33fXyb07Orf0Hysdd3QKIezTx+h97Hm+m9XLPl94fOnp6cPaR23+wOs5fSW96fMHuWtgSW2dnp5eq96FL1e/0u3+0pvYb+k/tndHictA5F8ngt7f3js995+o9Ty+9PT09NGcP3in5+NFX93e6P0tXvCl2GymefK7dGxBfDadyM5kahLjyebpeO18m6B371XT7f7Sm7X/5PD5g9CHc0yd7/8LzT2a0vvbwyyBPb66HR/p6enpCzLdjqeRPV5rsj/Q+/t8KZvv4y9cwFp09VLv0ENfLWlvmhhOjZz9oeHWw3dsfae1Ox1LFv0L6N171XS7v/Rm7T85uTOwWMOCDrz8fyF6dT3MEszjq9vxkZ6enn5ur9vxNJrHa332B3p/ny/l8X38hW8hTGSnW4aPrzz3ynSi4d22R2pnRjcc2dZ28eR8A6J371XT7f7Sm7X/5LB6BQBwodvxkZ6enn6+HtWl2/5AH7L1iiLXwMrEY+eWdb/QNfjGnX2D9+5JZK7e89F+lzHRu/eq6XZ/6c3afxy+XAMLABBiuh0f6enp6aEn3fYH+jCtVxRZwErHay/c2C6EmLVio3VLrjQub5oYdvkL6N171XS7v/Rm7T8OzsACALjT7fhIT09PDz3ptj/Qh2m9osgCViIz1fLNqYVTIzfMjKZGP2seOz/S2OLyF9C796rpdn/pzdp/HJyBBQBwp9vxkZ6enh560m1/oA/TekXhRdyFELFs9rbhIw8c3/ltc2vHx6+N1S893bre5S+gd+9V0+3+0pu1/zg4AwsA4E634yM9PT099KTb/kAfpvWKIgtY6fiCL25anfr+w9s/HxqvTx1a9dSZZWtd/gJ691413e4vvVn7j4NPIQQAuNPt+EhPT08PPem2P9CHab2iyAJWNpb8ZOmaAyt7myYvjyebZ+b5gEN6j71qut1fen/7YLB6BQBwp9vxkZ6enh560m1/oHfvVfN3/EUWsK7dzIqP1KckhkVfVbrdX3p/e9U4AwtasTo65/vWQLfoGaJX28MsAT++uh0f6enp6XN0O55G/Hhd9f2B3t/nS4X8Gn/hRdxtq2Y60ZipSXjcLn116XZ/6f3tA8PqFfThcjQVQvQMiYFueoU9zBLY46vb8ZGenp6+gG7H08gerzXZH+jdM3PnD4ULWJOJhlNtD3+5eIXHv4C+unS7v/T+9oHhUwihCfejqSP/mErvbw+zBPn46nZ8pKenp8+n2/E0ysdrHfYHen+fL5XwffzWoy+lbVtiBJYl6OnpNekHt15bnH7ug2tfcRahCs6lyp1dNd8SVcEZWNt/fMXg9YE+yD63P3she24zfcn+6Hqe7yb1ss8XHl96evqw9lGbP/B6Tl9Jb/r8oSYZF8m4SMSEZXnaOj09vT69X+Y7A0u3+0sf7l6K7Dvz6Uv2uu0P9O69FB5fenr6EPdSNDz+yva6/fzpzeqlaLi/WfbpY3t3dG7pP1a67ugUQtC795t2p0uWP91Efnly310ntbq/9PQ696qfj/T+9ry+0Ue5l3294vlCT09PT09PH7W+Rgixpf/Y3h0l3gaZ/z5JevfeO6l/3eV63e4vPb3OvXflPR/p/e1123/o6XV+vdJt/PT09PT09PT0SvsaL7eZO7uid+9V0+3+0tPr3MMsuu0/9PRB9rJ0Gz89PT09PT09vbo+7vE2c9FXl273l55e5x5m0W3/oacPspel2/jp6enp6enp6RX1Nd63CAAAAAAAAASPBSwAAAAAAABojQUsAAAAAAAAaI0FLAAAAAAAAGiNBSwAAAAAAABojQUsAAAAAAAAaI0FLAAAAAAAAGgtnvud1dE5XzTQLXqGCr9I796rptv9pafXuYdZdNt/6OmD7GXpNn56enp6enp6ekV9TclaCNEzJAa6vW6dPgC63V96ep17mEW3/YeePshelm7jp6enp6enp6dX19eUrOfeht69V023+0tPr3MPs+i2/9DT6/x6pdv46enp6enp6emV9lbJNN+A5LnuEeyPrk/btsRNLEtI9YNbE1Lj0e3nQ08fZK/6+Ujvb8/rG32Ue9nXK54v9PT09PT09FHr5RawUNITf0rbtpi1RWa29L/cLEsk40Kqf/lJiQkrEHGqn4/0/va8viHKZF+veL4AAICo+f/elNjhpnnY9wAAAABJRU5ErkJggg==\n",
281
+ "text/plain": [
282
+ "<PIL.Image.Image image mode=RGB size=1600x224>"
283
+ ]
284
+ },
285
+ "execution_count": 47,
286
+ "metadata": {},
287
+ "output_type": "execute_result"
288
+ }
289
+ ],
290
+ "source": [
291
+ "img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]\n",
292
+ "img"
293
+ ]
294
+ },
295
+ {
296
+ "cell_type": "markdown",
297
+ "id": "7233c86a-eb02-48cb-8369-bb8a521bc330",
298
+ "metadata": {
299
+ "tags": []
300
+ },
301
+ "source": [
302
+ "#### Check if the model generated the correct level\n",
303
+ "##### Because of the stochastic nature of the model and the small training dataset, the model may generate levels that do not completely match the given prompt"
304
+ ]
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "execution_count": 50,
309
+ "id": "d3489875-e648-4c75-97f0-7ae55dc51b81",
310
+ "metadata": {},
311
+ "outputs": [
312
+ {
313
+ "data": {
314
+ "text/plain": [
315
+ "'some pipes, many enemies, some blocks, high elevation'"
316
+ ]
317
+ },
318
+ "execution_count": 50,
319
+ "metadata": {},
320
+ "output_type": "execute_result"
321
+ }
322
+ ],
323
+ "source": [
324
+ "mario_lm.prompter(generated_level)[0]"
325
+ ]
326
+ }
327
+ ],
328
+ "metadata": {
329
+ "kernelspec": {
330
+ "display_name": "Python [conda env:py39] *",
331
+ "language": "python",
332
+ "name": "conda-env-py39-py"
333
+ },
334
+ "language_info": {
335
+ "codemirror_mode": {
336
+ "name": "ipython",
337
+ "version": 3
338
+ },
339
+ "file_extension": ".py",
340
+ "mimetype": "text/x-python",
341
+ "name": "python",
342
+ "nbconvert_exporter": "python",
343
+ "pygments_lexer": "ipython3",
344
+ "version": "3.9.0"
345
+ }
346
+ },
347
+ "nbformat": 4,
348
+ "nbformat_minor": 5
349
+ }
setup.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import re
4
+ from os import path
5
+
6
+ from setuptools import find_packages
7
+ from setuptools import setup
8
+
9
+
10
+ this_directory = path.abspath(path.dirname(__file__))
11
+ with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
12
+ long_description = f.read()
13
+
14
+
15
+ setup(
16
+ name="mario-gpt",
17
+ version="0.1.0",
18
+ url="https://github.com/kragniz/cookiecutter-pypackage-minimal",
19
+ license='MIT',
20
+
21
+ author="Shyam Sudhakaran",
22
+ author_email="shyamsnair@protonmail.com",
23
+
24
+ description="Generating Mario Levels with GPT2. Code for the paper: 'MarioGPT: Open-Ended Text2Level Generation through Large Language Models', https://arxiv.org/abs/2302.05981",
25
+
26
+ long_description=long_description,
27
+ long_description_content_type="text/markdown",
28
+
29
+ packages=find_packages(exclude=('tests',)),
30
+
31
+ install_requires=[
32
+ 'torch',
33
+ 'transformers',
34
+ 'scipy',
35
+ 'tqdm'
36
+ ],
37
+
38
+ classifiers=[
39
+ 'Development Status :: 2 - Pre-Alpha',
40
+ 'License :: OSI Approved :: MIT License',
41
+ 'Programming Language :: Python :: 3',
42
+ ],
43
+ )
static/architecture.png ADDED
static/prompt-samples.png ADDED