File size: 6,627 Bytes
850b0e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
from __future__ import annotations

import random
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from scipy import stats
from transformers import pipeline

from mario_gpt.dataset import MarioDataset
from mario_gpt.utils import view_level

# Default per-feature count thresholds used to bucket raw counts into the
# keywords "no"/"little"/"some"/"many" (via np.digitize in Prompter).
# The three values per feature correspond to the 0.33/0.66/0.95 quantiles
# produced by Prompter.dataset_statistics — presumably precomputed from the
# training dataset (TODO confirm provenance).
STATISTICS = {
    "enemy": np.array([1.0, 3.0, 7.0]),
    "pipe": np.array([0.0, 2.0, 5.0]),
    "block": np.array([50.0, 75.0, 176.0]),
}

# Hugging Face model name used for the prompt feature-extraction pipeline.
FEATURE_EXTRACTION_MODEL = "facebook/bart-base"


class Prompter:
    """Build natural-language prompts (and their embeddings) from Mario levels.

    A level is rendered to strings, its features are counted (pipes, enemies,
    blocks, elevation), and each count is mapped to a keyword ("no"/"little"/
    "some"/"many") by comparing against quantile thresholds. The resulting
    prompt string is encoded to a hidden vector with a Hugging Face
    feature-extraction pipeline.
    """

    def __init__(
        self,
        level_tokenizer,
        prompter_model: str = FEATURE_EXTRACTION_MODEL,
        use_raw_counts: bool = False,
        statistics: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            level_tokenizer: tokenizer used to render level token ids back to
                strings (passed to ``view_level``).
            prompter_model: model name for the feature-extraction pipeline.
            use_raw_counts: if True, prompts use exact counts ("3 pipes")
                instead of quantile keywords ("some pipes").
            statistics: optional mapping of feature name ("enemy"/"pipe"/
                "block") to threshold array; defaults to module-level
                ``STATISTICS``.
        """
        self.prompter_model = prompter_model
        self.feature_extraction = pipeline(
            "feature-extraction",
            model=prompter_model,
            tokenizer=prompter_model,
            framework="pt",
        )

        self.level_tokenizer = level_tokenizer

        self.use_raw_counts = use_raw_counts
        self.statistics = statistics if statistics is not None else STATISTICS

    @property
    def pipe_thresholds(self) -> Tuple[np.ndarray, List[str]]:
        """Thresholds and bucket keywords for the pipe count."""
        thresholds = self.statistics["pipe"]
        keywords = ["no", "little", "some", "many"]
        return thresholds, keywords

    @property
    def enemy_thresholds(self) -> Tuple[np.ndarray, List[str]]:
        """Thresholds and bucket keywords for the enemy count."""
        thresholds = self.statistics["enemy"]
        keywords = ["no", "little", "some", "many"]
        return thresholds, keywords

    @property
    def block_thresholds(self) -> Tuple[np.ndarray, List[str]]:
        """Thresholds and bucket keywords for the block count.

        "little" appears twice on purpose: levels always contain blocks, so
        there is no "no blocks" bucket.
        """
        thresholds = self.statistics["block"]
        keywords = ["little", "little", "some", "many"]
        return thresholds, keywords

    def count_pipes(self, flattened_level: str) -> int:
        """Count pipes — each "<>" pair is one pipe opening."""
        return flattened_level.count("<>")

    def count_enemies(self, flattened_level: str) -> int:
        """Count enemy tiles ("E" and "B")."""
        return flattened_level.count("E") + flattened_level.count("B")

    def count_blocks(self, flattened_level: str) -> int:
        """Count solid/interactive block tiles ("X", "S", "?", "Q")."""
        # int(...) so callers get a plain int rather than np.int64.
        return int(sum(flattened_level.count(char) for char in ["X", "S", "?", "Q"]))

    def _flatten_level(self, string_level: List[str]) -> str:
        """Join the per-row strings of a level into one string."""
        return "".join(string_level)

    def _count_keyword(
        self, count: int, thresholds: np.ndarray, keywords: List[str]
    ) -> str:
        """Map a raw count to its prompt keyword (or the raw number).

        With ``right=True`` a count equal to thresholds[i] falls in bucket i,
        so e.g. count == 0 with pipe thresholds [0, 2, 5] yields "no".
        """
        if self.use_raw_counts:
            return f"{count}"
        bucket = np.digitize(count, thresholds, right=True)
        return keywords[bucket]

    def pipe_prompt(self, flattened_level: str, level: str) -> Tuple[str, str]:
        """Return ("<keyword> pipes", keyword) for the level's pipe count."""
        count = self.count_pipes(flattened_level)
        thresholds, keywords = self.pipe_thresholds
        keyword = self._count_keyword(count, thresholds, keywords)
        return f"{keyword} pipes", keyword

    def enemy_prompt(self, flattened_level: str, level: str) -> Tuple[str, str]:
        """Return ("<keyword> enemies", keyword) for the level's enemy count."""
        count = self.count_enemies(flattened_level)
        thresholds, keywords = self.enemy_thresholds
        keyword = self._count_keyword(count, thresholds, keywords)
        return f"{keyword} enemies", keyword

    def block_prompt(self, flattened_level: str, level: str) -> Tuple[str, str]:
        """Return ("<keyword> blocks", keyword) for the level's block count."""
        count = self.count_blocks(flattened_level)
        thresholds, keywords = self.block_thresholds
        keyword = self._count_keyword(count, thresholds, keywords)
        return f"{keyword} blocks", keyword

    def elevation_prompt(self, flattened_level: str, level: str) -> Tuple[str, str]:
        """Return ("high elevation", "high") if any structure appears in the
        top 6 rows of the level, else ("low elevation", "low")."""
        top_levels = level[:6]  # elevation 8 and up
        for t in top_levels:
            if "X" in t or "<" in t or ">" in t:
                return "high elevation", "high"
        return "low elevation", "low"

    def output_hidden(
        self, prompt: str, device: torch.device = torch.device("cpu")
    ) -> torch.Tensor:
        """Encode a prompt to a (1, hidden_dim) tensor by mean-pooling the
        feature-extraction output over the token dimension.

        hidden_dim is the prompter model's hidden size (768 for bart-base —
        TODO confirm if a different model is configured).
        """
        return (
            self.feature_extraction(prompt, return_tensors="pt")[0]
            .mean(0)
            .to(device)
            .view(1, -1)
        )

    def dataset_statistics(self, dataset: MarioDataset) -> Dict[str, Any]:
        """Compute 0.33/0.66/0.95 count quantiles over a dataset.

        Returns a dict with keys "enemy"/"pipe"/"block" mapping to the
        quantile arrays, suitable as the ``statistics`` constructor argument.
        """
        enemy_counts = []
        pipe_counts = []
        block_counts = []
        for i in range(len(dataset)):
            level, _ = dataset[i]
            str_level = self._flatten_level(view_level(level, dataset.tokenizer))

            enemy_counts.append(self.count_enemies(str_level))
            pipe_counts.append(self.count_pipes(str_level))
            block_counts.append(self.count_blocks(str_level))

        d: Dict[str, Any] = {}
        d["enemy"] = stats.mstats.mquantiles(enemy_counts, [0.33, 0.66, 0.95])
        d["pipe"] = stats.mstats.mquantiles(pipe_counts, [0.33, 0.66, 0.95])
        d["block"] = stats.mstats.mquantiles(block_counts, [0.33, 0.66, 0.95])
        return d

    def __call__(
        self, level: torch.Tensor = None, sample_prompt: bool = False
    ) -> Tuple[str, torch.Tensor, Dict[str, str], Optional[List[str]]]:
        """Produce a prompt for a level, or sample a random prompt.

        Args:
            level: token tensor to describe; required unless ``sample_prompt``.
            sample_prompt: if True, ignore ``level`` and sample keywords
                uniformly (with "little" weighted double for blocks, since
                levels always have blocks).

        Returns:
            (prompt string, hidden embedding, per-feature prompt dict,
            rendered level strings or None when sampling).
        """
        device: torch.device = torch.device("cpu")
        if not sample_prompt:
            if level is None:
                raise ValueError("Level must be provided if sample_prompt is not true!")
            str_level = view_level(level, self.level_tokenizer)
            flattened_level = self._flatten_level(str_level)

            pipe_prompt, _ = self.pipe_prompt(flattened_level, str_level)
            enemy_prompt, _ = self.enemy_prompt(flattened_level, str_level)
            block_prompt, _ = self.block_prompt(flattened_level, str_level)
            elevation_prompt, _ = self.elevation_prompt(flattened_level, str_level)
            device = level.device
        else:
            str_level = None
            pipe_prompt = random.choice(["no", "little", "some", "many"]) + " pipes"
            enemy_prompt = random.choice(["no", "little", "some", "many"]) + " enemies"
            block_prompt = (
                random.choice(["little", "little", "some", "many"]) + " blocks"
            )  # levels always have blocks
            elevation_prompt = random.choice(["low", "high"]) + " elevation"

        # NOTE(review): the "elevation_prompt" key is inconsistent with the
        # other keys ("pipe"/"enemy"/"block") but is kept for backward
        # compatibility with existing callers.
        prompt_dict = {
            "pipe": pipe_prompt,
            "enemy": enemy_prompt,
            "block": block_prompt,
            "elevation_prompt": elevation_prompt,
        }
        prompt = f"{pipe_prompt}, {enemy_prompt}, {block_prompt}, {elevation_prompt}"
        hidden = self.output_hidden(prompt, device=device)
        return prompt, hidden, prompt_dict, str_level