from typing import List, Optional

import numpy as np
import torch
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,  # non-deprecated replacement for AutoModelWithLMHead
    AutoTokenizer,
    LogitsProcessorList,
    PreTrainedModel,
    PreTrainedTokenizer,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
)

from mario_gpt.prompter import Prompter

# GPT-2 checkpoint fine-tuned to generate Mario levels, trained with a
# 700-token context window (hence the default context_len below).
PRETRAINED_MODEL_PATH = "shyamsn97/Mario-GPT2-700-context-length"


class MarioLM:
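    """GPT-2 style language model that writes Mario levels one tile token
    at a time.

    Levels are serialized column by column as strings of tile characters,
    14 tiles per column, which is why the context bookkeeping below works
    in multiples of 14.
    """
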
    def __init__(
        self,
        lm: Optional[PreTrainedModel] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        context_len: int = 700,
        prompter: Optional[Prompter] = None,
    ):
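        """
        Args:
            lm: causal LM backbone; the pretrained MarioGPT checkpoint is
                loaded if omitted.
            tokenizer: matching tokenizer; loaded from the same checkpoint
                if omitted.
            context_len: context window (in tokens) the model was trained with.
            prompter: encodes text prompts into hidden states; a default
                Prompter wrapping the tokenizer is built if omitted.
        """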
        self.context_len = context_len
        self.lm = lm

        if lm is None:
            self.lm = self.load_pretrained_lm()

        self.tokenizer = tokenizer
        if tokenizer is None:
            self.tokenizer = self.load_pretrained_tokenizer()

        self.prompter = prompter
        if prompter is None:
            self.prompter = Prompter(self.tokenizer)

    @property
    def device(self):
        return self.lm.device

    def to(self, device: torch.device):
        self.lm = self.lm.to(device)
        return self

    def load_pretrained_lm(self) -> PreTrainedModel:
        print(f"Using {PRETRAINED_MODEL_PATH} model")
        return AutoModelForCausalLM.from_pretrained(PRETRAINED_MODEL_PATH)

    def load_pretrained_tokenizer(self) -> PreTrainedTokenizer:
        print(f"Using {PRETRAINED_MODEL_PATH} tokenizer")
        return AutoTokenizer.from_pretrained(PRETRAINED_MODEL_PATH)

    def sample_step(
        self,
        seed: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temperature: float = 2.0,
    ):
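        """Sample the next token for each sequence in `seed`.

        Logits for the final position are restricted to the top 16
        candidates (the tile vocabulary) and rescaled by `temperature`
        before multinomial sampling.
        """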
        lm = self.lm
        logits_processor = LogitsProcessorList()
        logits_warper = LogitsProcessorList(
            [
                TopKLogitsWarper(16),  # restrict to the 16 tile characters
                TemperatureLogitsWarper(temperature),
            ]
        )
        with torch.no_grad():
            attention_mask = torch.ones_like(seed).to(seed.device)
            input_ids = seed
            out = lm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                token_type_ids=None,
            )
            logits = out.logits.detach()
            if len(logits.shape) == 2:
                logits = logits.view(1, 1, -1)
            next_token_logits = logits[:, -1, :]

            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)
            probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        return next_tokens, encoder_hidden_states

    def sample(
        self,
        seed: Optional[torch.Tensor] = None,
        prompts: Optional[List[str]] = None,
        num_steps: int = 1,
        temperature: float = 2.0,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        use_tqdm: bool = False,
    ):
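        """Autoregressively generate `num_steps` tile tokens.

        Conditioning comes from `encoder_hidden_states` if given, otherwise
        from encoding `prompts`, otherwise from randomly sampled prompts.
        Returns the generated token ids, including the seed.
        """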
        # Sample against a slightly shorter window than the trained context
        # (700 - 28 leaves two 14-token columns of slack).
        context_len = self.context_len - 28
        self.lm.eval()
        with torch.no_grad():
            if seed is None:
                # default seed: a single "X" (solid ground) tile token
                seed = self.tokenizer("X", return_tensors="pt").input_ids.view(1, 1)
            out = seed.to(self.device)
            if encoder_hidden_states is None:
                if prompts is not None:
                    encoder_hidden_states = torch.stack(
                        [self.prompter.output_hidden(prompt) for prompt in prompts]
                    )
                else:
                    encoder_hidden_states = torch.stack(
                        [
                            self.prompter(sample_prompt=True)[1]
                            for _ in range(seed.shape[0])
                        ]
                    )
            encoder_hidden_states = encoder_hidden_states.to(
                self.device
            )  # b x 1 x hidden_dim
            encoder_hidden_states = encoder_hidden_states.view(seed.shape[0], 1, -1)
            if not use_tqdm:
                bar = np.arange(num_steps)
            else:
                bar = tqdm(np.arange(num_steps))
            for _ in bar:
                inp = out.clone()
                if out.shape[-1] > context_len:
                    # Trim to a sliding window. Levels are 14 tiles tall and
                    # serialized column by column; with the default settings
                    # context_len (672) is a multiple of 14, so keeping
                    # context_len + (length % 14) tokens aligns the window
                    # to a column boundary.
                    diff = inp.shape[-1] % 14  # height of a mario level
                    ctx = context_len + diff
                    inp = inp[:, -ctx:].clone()
                next_tokens, encoder_hidden_states = self.sample_step(
                    inp,
                    encoder_hidden_states=encoder_hidden_states,
                    temperature=temperature,
                )
                out = torch.cat([out, next_tokens.unsqueeze(-1)], dim=-1)
                if use_tqdm:
                    bar.set_description(
                        f"shape: {inp.shape}, {out.shape} first: {inp[0][0]}, last: {out[0][-1]}"
                    )
            if use_tqdm:
                bar.close()
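        # restore train mode (unconditionally, regardless of the model's
        # mode before sampling)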
        self.lm.train()
        return out
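

# Minimal usage sketch, assuming the pretrained checkpoint above is reachable.
# The prompt wording follows the "many/some/few" phrasing MarioGPT's Prompter
# was trained with, and num_steps = 1400 corresponds to 100 columns x 14
# tiles; treat both as illustrative values rather than requirements.
if __name__ == "__main__":
    mario_lm = MarioLM()
    generated = mario_lm.sample(
        prompts=["many pipes, many enemies, some blocks, high elevation"],
        num_steps=1400,
        temperature=2.0,
        use_tqdm=True,
    )
    # decode the token ids back into the level's tile-character string
    print(mario_lm.tokenizer.decode(generated[0]))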