from typing import List, Optional

import numpy as np
import torch
from tqdm import tqdm
from transformers import (
    AutoModelWithLMHead,
    AutoTokenizer,
    LogitsProcessorList,
    PreTrainedModel,
    PreTrainedTokenizer,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
)

from mario_gpt.prompter import Prompter

PRETRAINED_MODEL_PATH = "shyamsn97/Mario-GPT2-700-context-length"
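

# MarioLM wraps a GPT-2 language model fine-tuned to generate Super Mario Bros.
# levels token by token. Prompt text is encoded by a Prompter and passed to the
# model as encoder hidden states, so every sampled token is conditioned on the
# prompt.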
class MarioLM:
    def __init__(
        self,
        lm: Optional[PreTrainedModel] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        context_len: int = 700,
        prompter: Optional[Prompter] = None,
    ):
        self.context_len = context_len

        self.lm = lm
        if lm is None:
            self.lm = self.load_pretrained_lm()

        self.tokenizer = tokenizer
        if tokenizer is None:
            self.tokenizer = self.load_pretrained_tokenizer()

        self.prompter = prompter
        if prompter is None:
            self.prompter = Prompter(self.tokenizer)
    @property
    def device(self):
        return self.lm.device

    def to(self, device: torch.device):
        self.lm = self.lm.to(device)
        return self
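
    # Note: AutoModelWithLMHead is deprecated in recent transformers releases;
    # AutoModelForCausalLM is the usual replacement for GPT-2 style checkpoints.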
    def load_pretrained_lm(self) -> PreTrainedModel:
        print(f"Using {PRETRAINED_MODEL_PATH} model")
        return AutoModelWithLMHead.from_pretrained(PRETRAINED_MODEL_PATH)

    def load_pretrained_tokenizer(self) -> PreTrainedTokenizer:
        print(f"Using {PRETRAINED_MODEL_PATH} tokenizer")
        return AutoTokenizer.from_pretrained(PRETRAINED_MODEL_PATH)
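
    # sample_step draws one token: run a forward pass, keep the logits at the
    # final position, restrict them to the top-k candidates, rescale by
    # temperature, and sample from the resulting distribution.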
    def sample_step(
        self,
        seed: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temperature: float = 2.0,
    ):
        lm = self.lm
        logits_processor = LogitsProcessorList()
        logits_warper = LogitsProcessorList(
            [
                TopKLogitsWarper(16),  # top-k = number of distinct tile characters
                TemperatureLogitsWarper(temperature),
            ]
        )
        with torch.no_grad():
            attention_mask = torch.ones_like(seed)  # no padding: attend everywhere
            input_ids = seed
            out = lm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                token_type_ids=None,
            )
            logits = out.logits.detach()
            if len(logits.shape) == 2:  # unbatched output: add batch/seq dims
                logits = logits.view(1, 1, -1)
            next_token_logits = logits[:, -1, :]  # scores for the next position only
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)
            probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        return next_tokens, encoder_hidden_states
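
    # sample autoregressively extends a seed for num_steps tokens, conditioning
    # each step on the prompt hidden states and sliding the input window once
    # the sequence outgrows the model's context length.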
    def sample(
        self,
        seed: Optional[torch.Tensor] = None,
        prompts: Optional[List[str]] = None,
        num_steps: int = 1,
        temperature: float = 2.0,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        use_tqdm: bool = False,
    ):
        # Keep headroom below the model's 700-token window; 28 tokens is
        # presumably two 14-tile level columns.
        context_len = self.context_len - 28
        self.lm.eval()
        with torch.no_grad():
            if seed is None:
                seed = self.tokenizer("X", return_tensors="pt").input_ids.view(1, 1)
            out = seed.to(self.device)
            if encoder_hidden_states is None:
                if prompts is not None:
                    encoder_hidden_states = torch.stack(
                        [self.prompter.output_hidden(prompt) for prompt in prompts]
                    )
                else:
                    # No prompts given: draw a random prompt per batch element.
                    encoder_hidden_states = torch.stack(
                        [
                            self.prompter(sample_prompt=True)[1]
                            for _ in range(seed.shape[0])
                        ]
                    )
            encoder_hidden_states = encoder_hidden_states.to(
                self.device
            )  # b x 1 x hidden_dim
            encoder_hidden_states = encoder_hidden_states.view(seed.shape[0], 1, -1)
            if use_tqdm:
                bar = tqdm(np.arange(num_steps))
            else:
                bar = np.arange(num_steps)
            for _ in bar:
                inp = out.clone()
                if out.shape[-1] > context_len:
                    # Cut the window on a column boundary: levels are 14 tiles
                    # tall, so keeping context_len + (length % 14) tokens keeps
                    # the truncated input aligned with the start of a column
                    # (context_len is a multiple of 14 for the 700 default).
                    diff = inp.shape[-1] % 14  # height of mario level
                    ctx = context_len + diff
                    inp = inp[:, -ctx:].clone()
                next_tokens, encoder_hidden_states = self.sample_step(
                    inp,
                    encoder_hidden_states=encoder_hidden_states,
                    temperature=temperature,
                )
                out = torch.cat([out, next_tokens.unsqueeze(-1)], dim=-1)
                if use_tqdm:
                    bar.set_description(
                        f"shape: {inp.shape}, {out.shape} first: {inp[0][0]}, last: {out[0][-1]}"
                    )
        if use_tqdm:
            bar.close()
        self.lm.train()  # put the model back in training mode
        return out
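

# A minimal usage sketch (illustrative only: the prompt wording and step count
# below are assumptions, not part of this module):
#
#   mario_lm = MarioLM()
#   levels = mario_lm.sample(
#       prompts=["many pipes, many enemies, some blocks, high elevation"],
#       num_steps=1400,
#       temperature=2.0,
#       use_tqdm=True,
#   )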