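"""MarioLM: a thin wrapper around the MarioGPT language model.

Loads the pretrained GPT-2 checkpoint ("shyamsn97/Mario-GPT2-700-context-length")
and samples Super Mario Bros. level tiles one token at a time, conditioned on a
text prompt through encoder hidden states produced by a Prompter.
"""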
from typing import List, Optional
import numpy as np
import torch
from tqdm import tqdm
from transformers import (
    AutoModelWithLMHead,
    AutoTokenizer,
    LogitsProcessorList,
    PreTrainedModel,
    PreTrainedTokenizer,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
)
from mario_gpt.prompter import Prompter
PRETRAINED_MODEL_PATH = "shyamsn97/Mario-GPT2-700-context-length"
class MarioLM:
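    """Prompt-conditioned GPT-2 language model that generates Mario levels tile by tile."""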
def __init__(
self,
lm: Optional[PreTrainedModel] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
context_len: int = 700,
prompter: Optional[Prompter] = None,
):
self.context_len = context_len
self.lm = lm
if lm is None:
self.lm = self.load_pretrained_lm()
self.tokenizer = tokenizer
if tokenizer is None:
self.tokenizer = self.load_pretrained_tokenizer()
self.prompter = prompter
if prompter is None:
self.prompter = Prompter(self.tokenizer)
@property
def device(self):
return self.lm.device
def to(self, device: torch.device):
self.lm = self.lm.to(device)
return self
    def load_pretrained_lm(self) -> PreTrainedModel:
        # AutoModelWithLMHead is deprecated in recent transformers releases;
        # AutoModelForCausalLM is the recommended replacement if the warning appears.
        print(f"Using {PRETRAINED_MODEL_PATH} model")
        return AutoModelWithLMHead.from_pretrained(PRETRAINED_MODEL_PATH)
    def load_pretrained_tokenizer(self) -> PreTrainedTokenizer:
        print(f"Using {PRETRAINED_MODEL_PATH} tokenizer")
        return AutoTokenizer.from_pretrained(PRETRAINED_MODEL_PATH)
def sample_step(
self,
seed: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temperature: float = 2.0,
):
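        """Run one forward pass and sample the next tile token for each sequence.

        Returns the sampled token ids and the (unchanged) encoder hidden states.
        """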
lm = self.lm
logits_processor = LogitsProcessorList()
logits_warper = LogitsProcessorList(
[
TopKLogitsWarper(16), # number of characters
TemperatureLogitsWarper(temperature),
]
)
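        # Forward pass: the prompt embedding is passed in as encoder hidden states,
        # so the (cross-attention-enabled) GPT-2 can condition on it.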
with torch.no_grad():
attention_mask = torch.ones_like(seed).to(seed.device)
input_ids = seed
out = lm(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
token_type_ids=None,
)
logits = out.logits.detach()
if len(logits.shape) == 2:
logits = logits.view(1, 1, -1)
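            # Only the logits at the last position matter for sampling the next token.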
next_token_logits = logits[:, -1, :]
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
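            # Convert the warped scores to probabilities and sample one token per sequence.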
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
return next_tokens, encoder_hidden_states
def sample(
self,
seed: Optional[torch.Tensor] = None,
prompts: Optional[List[str]] = None,
num_steps: int = 1,
temperature: float = 2.0,
        encoder_hidden_states: Optional[torch.Tensor] = None,
use_tqdm: bool = False,
):
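        """Autoregressively sample `num_steps` level tokens.

        Starts from `seed` (or a single default tile if none is given), conditions
        on `prompts` via the Prompter's hidden states, and returns the full token
        sequence, seed included.
        """
        # Keep a margin below the model's 700-token context window
        # (28 tokens = two 14-tile columns, matching the constants used below).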
context_len = self.context_len - 28
self.lm.eval()
with torch.no_grad():
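            # Default seed: a single "X" tile (ground/solid block in the level encoding).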
if seed is None:
seed = self.tokenizer("X", return_tensors="pt").input_ids.view(1, 1)
out = seed.to(self.device)
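            # Build the prompt conditioning if the caller did not pass it in:
            # either encode the given prompts or sample a random prompt per sequence.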
if encoder_hidden_states is None:
if prompts is not None:
encoder_hidden_states = torch.stack(
[self.prompter.output_hidden(prompt) for prompt in prompts]
)
else:
encoder_hidden_states = torch.stack(
[
self.prompter(sample_prompt=True)[1]
for _ in range(seed.shape[0])
]
)
encoder_hidden_states = encoder_hidden_states.to(
self.device
) # b x 1 x hidden_dim
encoder_hidden_states = encoder_hidden_states.view(seed.shape[0], 1, -1)
if not use_tqdm:
bar = np.arange(num_steps)
else:
bar = tqdm(np.arange(num_steps))
with torch.no_grad():
for i in bar:
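                # `* 1` makes a copy of the running sequence to use as the model input.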
inp = out * 1
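                # If the sequence outgrows the context window, keep only the most recent
                # tokens, adjusted so the truncation lines up with the 14-tile column height.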
if len(out.shape) > 0 and out.shape[-1] > context_len:
diff = inp.shape[-1] % 14 # height of mario level
ctx = context_len + diff
inp = inp[:, -ctx:] * 1
next_tokens, encoder_hidden_states = self.sample_step(
inp,
encoder_hidden_states=encoder_hidden_states,
temperature=temperature,
)
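                # Append the newly sampled tile token to the running sequence.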
out = torch.cat([out, next_tokens.unsqueeze(-1)], dim=-1)
if use_tqdm:
bar.set_description(
f"shape: {inp.shape}, {out.shape} first: {inp[0][0]}, last: {out[0][-1]}"
)
if use_tqdm:
bar.close()
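        # Put the model back into training mode (eval() was set at the top of sample()).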
self.lm.train()
return out
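

if __name__ == "__main__":
    # Minimal usage sketch. It assumes the pretrained checkpoint can be fetched
    # from the Hugging Face Hub and that the prompt phrasing below is one the
    # Prompter understands; adjust both to your setup.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mario_lm = MarioLM().to(device)
    level_tokens = mario_lm.sample(
        prompts=["many pipes, many enemies, some blocks, high elevation"],
        num_steps=1400,  # 100 columns x 14 tiles per column
        temperature=2.0,
        use_tqdm=True,
    )
    # Decode the token ids back into level characters for a quick look.
    print(mario_lm.tokenizer.decode(level_tokens[0]))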