from __future__ import annotations

from typing import List, Optional

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

from mario_gpt.level import FULL_LEVEL_STR_WITH_PATHS

DEFAULT_MODEL = "distilgpt2"
def split_given_size(a, size):
    # Split `a` into consecutive chunks of length `size` (the last chunk may be shorter).
    return np.split(a, np.arange(size, len(a), size))


def flip_and_transpose(arr: np.ndarray, flip_first: bool = False):
    # Convert between row-major and column-major views of a level grid,
    # reversing tile order along the last axis.
    if arr.shape[-1] > 1:
        if flip_first:
            return np.flip(arr, -1).transpose()
        return np.flip(arr.transpose(), -1)
    return arr


def join_list_of_list(str_lists):
    # [["a", "b"], ["c", "d"]] -> ["ab", "cd"]
    return ["".join(s) for s in str_lists]


def characterize(str_lists):
    # ["ab", "cd"] -> [["a", "b"], ["c", "d"]]
    return [list(s) for s in str_lists]
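
# Example: for a toy 2x3 level given as rows, flip_and_transpose yields one row
# per level column, read bottom-to-top, which is the layout MarioDataset
# tokenizes:
#
#   rows = np.array(characterize(["ab-", "cd-"]))  # shape (2, 3)
#   cols = flip_and_transpose(rows)                # shape (3, 2)
#   join_list_of_list(cols)                        # -> ["ca", "db", "--"]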
class MarioDataset(Dataset):
    def __init__(
        self,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        level_string: Optional[str] = None,
        context_len: int = 700,
        height: int = 14,
        remove_start_end_tokens: bool = False,
        sample_all_indices: bool = False,
    ):
        if level_string is None:
            print(
                "No level string specified, using default string FULL_LEVEL_STR_WITH_PATHS..."
            )
            level_string = FULL_LEVEL_STR_WITH_PATHS
        elif ".txt" in level_string:
            # Treat the argument as a path and read the level from disk.
            with open(level_string, "r") as file:
                level_string = file.read()

        self.character_set = set(level_string)
        if "\n" in self.character_set:
            self.character_set.remove("\n")
        self.vocab_size = len(self.character_set)
        self.sample_all_indices = sample_all_indices

        def get_training_corpus():
            yield list(level_string)

        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
        self.tokenizer = tokenizer

        # Train a fresh tokenizer on the level's characters (yielded one
        # character at a time, so every tile becomes its own token).
        if getattr(tokenizer, "train_new_from_iterator", None) is not None:
            self.tokenizer = tokenizer.train_new_from_iterator(
                get_training_corpus(), 52000
            )
        elif getattr(tokenizer, "train_from_iterator", None) is not None:
            self.tokenizer = PreTrainedTokenizerFast(tokenizer_object=self.tokenizer)
            self.tokenizer = self.tokenizer.train_new_from_iterator(
                get_training_corpus(), self.vocab_size
            )

        self.context_len = context_len
        self.height = height

        x, self.str_arr = self.convert_level_to_tensor(level_string.split("\n"))
        self.input_ids = x["input_ids"].squeeze()
        self.attention_masks = x["attention_mask"].squeeze()
        if remove_start_end_tokens:
            self.input_ids = self.input_ids[1:-1]
            self.attention_masks = self.attention_masks[1:-1]

        # Precompute the sliding-window indices used by __getitem__.
        self.indices = self.generate_indices()

        # Inverse-frequency weights for every token that appears in the level.
        self.unique_tokens, self.unique_counts = self.input_ids.unique(
            return_counts=True
        )
        self.weighted_unique_counts = (
            1.0 / self.unique_counts / torch.sum(self.unique_counts)
        )

        # Map each decoded tile character back to its token id.
        self.token_dict = {}
        string_tokens = list(self.tokenizer.decode(self.unique_tokens))
        for int_token, string_token in zip(self.unique_tokens, string_tokens):
            self.token_dict[string_token] = int_token

    def convert_level_to_tensor(self, level: List[str]):
        # Rows -> column-major string, then tokenize the whole level at once.
        str_arr = flip_and_transpose(np.array(characterize(level)))
        str_arr = "".join(join_list_of_list(str_arr))
        x = self.tokenizer(str_arr, return_tensors="pt")
        return x, str_arr

    def __len__(self):
        return self.indices.shape[0]

    def __getitem__(self, idx):
        indices = self.indices[idx]
        return self.input_ids[indices], self.attention_masks[indices]

    def generate_indices(self):
        # Sliding windows of length context_len; by default only windows that
        # start on a column boundary (every `height` tokens) are kept.
        out = []
        for idx in range(self.input_ids.shape[0] - self.context_len):
            if idx % self.height == 0 or self.sample_all_indices:
                arange = torch.arange(idx, idx + self.context_len)
                out.append(arange)
        return torch.stack(out)

    def sample_indices(self, batch_size):
        # Draw `batch_size` random windows of consecutive token indices.
        out = []
        for _ in range(batch_size):
            start_idx = np.random.randint(0, self.__len__() - self.context_len)
            indices = torch.arange(start_idx, start_idx + self.context_len)
            out.append(indices)
        return torch.stack(out)

    def __str__(self):
        # Render the tokenized level as text by decoding each token back to
        # its character(s).
        str_list = characterize(self.tokenizer.batch_decode(self.input_ids))
        string = "\n".join(
            join_list_of_list(flip_and_transpose(np.array(str_list), True))
        )
        return string

    def generate_mask(self, mask_len: int, batch_size: int = 1):
        # Build a (batch_size, mask_len) tensor filled with a mask token id
        # derived from tokenizing the "<mask>" string.
        mask_token = self.tokenizer("<mask>").input_ids[1]
        ones = torch.ones((batch_size, mask_len))
        return ones * mask_token

    def apply_mask(self, level, masked_indices, mask=None):
        # Replace the tokens at `masked_indices` with mask tokens, returning a
        # copy of `level` with shape (batch_size, sequence_length).
        if len(level.shape) == 1:
            level = level.unsqueeze(0)
        batch_size = level.shape[0]
        mask_len = masked_indices.shape[-1]
        if mask is None:
            mask = self.generate_mask(mask_len, batch_size)
        mask = mask.long().to(level.device)
        masked_level = level * torch.ones_like(level).to(level.device)
        masked_level[:, masked_indices] = mask
        return masked_level
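

if __name__ == "__main__":
    # Minimal usage sketch: build a dataset from the bundled default level and
    # mask one column of tiles. Exact tensor shapes and token counts depend on
    # the level and tokenizer actually used.
    dataset = MarioDataset(context_len=700)
    print(f"{len(dataset)} training windows of {dataset.context_len} tokens each")

    input_ids, attention_mask = dataset[0]             # one (context_len,) window
    masked_indices = torch.arange(0, dataset.height)   # first column of tiles
    masked = dataset.apply_mask(input_ids, masked_indices)
    print(masked.shape)                                # (1, context_len)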