from __future__ import annotations

from typing import List, Optional

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

from mario_gpt.level import FULL_LEVEL_STR_WITH_PATHS

DEFAULT_MODEL = "distilgpt2"


def split_given_size(a, size):
    # Split `a` into consecutive chunks of length `size`; the last chunk may be shorter.
    return np.split(a, np.arange(size, len(a), size))


def flip_and_transpose(arr: np.ndarray, flip_first: bool = False):
    # Rotate a level grid so that rows of the output correspond to columns
    # of the input, read bottom-to-top (flip_first=True applies the inverse).
    if arr.shape[-1] > 1:
        if flip_first:
            return np.flip(arr, -1).transpose()
        return np.flip(arr.transpose(), -1)
    return arr


def join_list_of_list(str_lists):
    return ["".join(s) for s in str_lists]


def characterize(str_lists):
    return [list(s) for s in str_lists]


class MarioDataset(Dataset):
    def __init__(
        self,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        level_string: Optional[str] = None,
        context_len: int = 700,
        height: int = 14,
        remove_start_end_tokens: bool = False,
        sample_all_indices: bool = False,
    ):
        if level_string is None:
            print(
                "No level string specified, using default string FULL_LEVEL_STR_WITH_PATHS..."
            )
            level_string = FULL_LEVEL_STR_WITH_PATHS
        elif level_string.endswith(".txt"):
            # Treat the argument as a path to a level file and read it in.
            with open(level_string, "r") as file:
                level_string = file.read()

        self.character_set = set(level_string)
        if "\n" in self.character_set:
            self.character_set.remove("\n")
        self.vocab_size = len(self.character_set)
        self.sample_all_indices = sample_all_indices

        def get_training_corpus():
            yield list(level_string)

        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
        self.tokenizer = tokenizer
        # Retrain the tokenizer on the level characters so every tile maps to
        # a single token; 52000 is only an upper bound on the vocabulary size.
        if getattr(tokenizer, "train_new_from_iterator", None) is not None:
            self.tokenizer = tokenizer.train_new_from_iterator(
                get_training_corpus(), 52000
            )
        elif getattr(tokenizer, "train_from_iterator", None) is not None:
            self.tokenizer = PreTrainedTokenizerFast(tokenizer_object=self.tokenizer)
            self.tokenizer = self.tokenizer.train_new_from_iterator(
                get_training_corpus(), self.vocab_size
            )

        self.context_len = context_len
        self.height = height

        x, self.str_arr = self.convert_level_to_tensor(level_string.split("\n"))
        self.x = x  # keep the full encoding around; __str__ decodes from it
        self.input_ids = x["input_ids"].squeeze()
        self.attention_masks = x["attention_mask"].squeeze()
        if remove_start_end_tokens:
            self.input_ids = self.input_ids[1:-1]
            self.attention_masks = self.attention_masks[1:-1]

        self.indices = self.generate_indices()

        self.unique_tokens, self.unique_counts = self.input_ids.unique(
            return_counts=True
        )
        self.weighted_unique_counts = (
            1.0 / self.unique_counts / torch.sum(self.unique_counts)
        )

        self.token_dict = {}
        string_tokens = list(self.tokenizer.decode(self.unique_tokens))
        for int_token, string_token in zip(self.unique_tokens, string_tokens):
            self.token_dict[string_token] = int_token

    def convert_level_to_tensor(self, level: List[str]):
        str_arr = flip_and_transpose(np.array(characterize(level)))
        str_arr = "".join(join_list_of_list(str_arr))
        x = self.tokenizer(str_arr, return_tensors="pt")
        return x, str_arr

    def __len__(self):
        return self.indices.shape[0]

    def __getitem__(self, idx):
        indices = self.indices[idx]
        return self.input_ids[indices], self.attention_masks[indices]

    def generate_indices(self):
        out = []
        for idx in range(self.input_ids.shape[0] - self.context_len):
            # Start a window at every column boundary (every `height` tokens),
            # or at every token position when sample_all_indices is set.
            if idx % self.height == 0 or self.sample_all_indices:
                arange = torch.arange(idx, idx + self.context_len)
                out.append(arange)
        return torch.stack(out)

    def sample_indices(self, batch_size):
        # Draw `batch_size` random windows of consecutive token positions.
        out = []
        for _ in range(batch_size):
            start_idx = np.random.randint(0, self.__len__() - self.context_len)
            indices = torch.arange(start_idx, start_idx + self.context_len)
            out.append(indices)
        return torch.stack(out)
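    # ------------------------------------------------------------------
    # Editorial note (a sketch, not part of the original module): the level
    # is serialized column-major. With a hypothetical 2x3 level,
    #
    #     [["a", "b", "c"],                [["d", "a"],
    #      ["d", "e", "f"]]   --becomes-->  ["e", "b"],
    #                                       ["f", "c"]]
    #
    # after flip_and_transpose, so the joined string is "daebfc": each run
    # of `height` consecutive tokens in input_ids is one vertical slice of
    # the level. This is why generate_indices starts windows at multiples
    # of `height`, and why __str__ below calls flip_and_transpose with
    # flip_first=True to invert the transform for display.
    # ------------------------------------------------------------------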
    def __str__(self):
        str_list = characterize(self.tokenizer.batch_decode(self.x["input_ids"]))
        string = "\n".join(
            join_list_of_list(flip_and_transpose(np.array(str_list), True))
        )
        return string

    def generate_mask(self, mask_len: int, batch_size: int = 1):
        # Encoding an empty string yields only the tokenizer's special
        # start/end tokens; the one at index 1 doubles as the mask token.
        mask_token = self.tokenizer("").input_ids[1]
        ones = torch.ones((batch_size, mask_len))
        return ones * mask_token

    def apply_mask(self, level, masked_indices, mask=None):
        if len(level.shape) == 1:
            level = level.unsqueeze(0)
        batch_size = level.shape[0]
        mask_len = masked_indices.shape[-1]
        if mask is None:
            mask = self.generate_mask(mask_len, batch_size)
        mask = mask.long().to(level.device)
        # Copy the level, then overwrite the selected positions with mask tokens.
        masked_level = level * torch.ones_like(level).to(level.device)
        masked_level[:, masked_indices] = mask
        return masked_level
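

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes the mario_gpt package is importable and that the distilgpt2
# tokenizer can be downloaded; all other names are defined above.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    # Build the dataset from the bundled default level string.
    dataset = MarioDataset(context_len=700, height=14)
    print(f"{len(dataset)} context windows")

    # Each item is an (input_ids, attention_mask) pair of length context_len.
    input_ids, attention_mask = dataset[0]
    print(input_ids.shape, attention_mask.shape)  # torch.Size([700]) twice

    # The dataset works with a standard DataLoader for batched training.
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch_ids, batch_masks = next(iter(loader))
    print(batch_ids.shape)  # torch.Size([4, 700])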