import itertools
import logging
from typing import ClassVar
import numpy as np
from scipy.fft import dct, idct
from tokenizers import ByteLevelBPETokenizer
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast
from transformers.processing_utils import ProcessorMixin
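

# FAST-style action tokenization: an action chunk of shape [batch, timesteps, action_dim]
# is transformed with a DCT along the time axis, the coefficients are scaled and rounded
# to integers, shifted by `min_token`, mapped to characters, and compressed with a
# byte-level BPE tokenizer. decode() inverts each of these steps.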
class FASTProcessor(ProcessorMixin):
    attributes: ClassVar[list[str]] = ["bpe_tokenizer"]
    bpe_tokenizer_class: str = "AutoTokenizer"

    def __init__(
        self,
        bpe_tokenizer: PreTrainedTokenizerFast,
        scale: float = 10,
        vocab_size: int = 1024,
        min_token: int = 0,
        *,
        action_dim: int | None = None,
        time_horizon: int | None = None,
    ):
        self.scale = scale
        self.vocab_size = vocab_size
        self.min_token = min_token
        # The action horizon and dimension are needed during decoding. They can be specified
        # in three ways (in order of priority):
        # 1. passed as kwargs to decode()
        # 2. set in the constructor
        # 3. cached from the last call to __call__() or decode()
        self.time_horizon = time_horizon
        self.action_dim = action_dim
        self.called_time_horizon = time_horizon
        self.called_action_dim = action_dim
        super().__init__(bpe_tokenizer)
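
    # Encode: DCT along the time axis, scale and round the coefficients, shift them by
    # min_token, map them to characters, and compress the resulting string with BPE.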
    def __call__(self, action_chunk: np.ndarray) -> list[list[int]]:
        assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
        if action_chunk.ndim == 2:
            action_chunk = action_chunk[None, ...]

        # Cache the time horizon and action dimension for decoding
        self.called_time_horizon = action_chunk.shape[-2]
        self.called_action_dim = action_chunk.shape[-1]

        dct_coeff = dct(action_chunk, axis=1, norm="ortho")
        dct_coeff = np.around(dct_coeff * self.scale)
        tokens = []
        for elem in dct_coeff:
            token_str = "".join(map(chr, np.maximum(elem.flatten() - self.min_token, 0).astype(int)))
            tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
        return tokens
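
    # Decode: BPE-decode back to characters, undo the min_token shift, reshape to
    # [time_horizon, action_dim], rescale, and apply the inverse DCT.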
    def decode(
        self,
        tokens: list[list[int]],
        *,
        time_horizon: int | None = None,
        action_dim: int | None = None,
    ) -> np.ndarray:
        self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
        self.action_dim = action_dim or self.action_dim or self.called_action_dim

        # Cache the time horizon and action dimension for the next call
        self.called_time_horizon = self.time_horizon
        self.called_action_dim = self.action_dim
        assert (
            self.time_horizon is not None and self.action_dim is not None
        ), "Tokenizer not initialized, call the processor on an action chunk once or pass in time_horizon and action_dim."

        decoded_actions = []
        for token in tokens:
            try:
                decoded_tokens = self.bpe_tokenizer.decode(token)
                decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token
                if decoded_dct_coeff.size > (size := (self.time_horizon * self.action_dim)):
                    logging.warning("Error decoding tokens: got too many DCT coefficients, truncating.")
                    decoded_dct_coeff = decoded_dct_coeff[:size]
                elif decoded_dct_coeff.size < size:
                    logging.warning("Error decoding tokens: got too few DCT coefficients, padding with zeros.")
                    decoded_dct_coeff = np.concatenate(
                        [
                            decoded_dct_coeff,
                            np.zeros(size - decoded_dct_coeff.size, dtype=decoded_dct_coeff.dtype),
                        ]
                    )
                decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
                assert decoded_dct_coeff.shape == (
                    self.time_horizon,
                    self.action_dim,
                ), f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
            except Exception as e:  # noqa: BLE001
                logging.warning(f"Error decoding tokens: {e}")
                logging.warning(f"Tokens: {token}")
                decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
            decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
        return np.stack(decoded_actions)
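
    # fit() trains the byte-level BPE tokenizer on character strings built from the quantized
    # DCT coefficients of a dataset of action chunks, then returns a ready-to-use processor.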
    @classmethod
    def fit(
        cls,
        action_data: list[np.ndarray],
        scale: float = 10,
        vocab_size: int = 1024,
        *,
        time_horizon: int | None = None,
        action_dim: int | None = None,
    ) -> "FASTProcessor":
        if action_data[0].ndim == 2:
            # Run DCT over all inputs
            dct_tokens = [  # each of shape [num_control_points * control_components]
                dct(a, axis=0, norm="ortho").flatten() for a in action_data
            ]
            # Quantize and find the token range
            rounded = np.around(np.concatenate(dct_tokens) * scale)
            max_token = int(rounded.max())
            min_token = int(rounded.min())
            min_vocab_size = max_token - min_token

            # Make a token iterator for BPE training
            def _token_iter():
                for tokens in dct_tokens:
                    rtokens = np.around(tokens * scale) - min_token
                    rtokens = rtokens.astype(int)
                    yield "".join(map(chr, rtokens))

            token_iter = _token_iter()
        elif action_data[0].ndim == 3:
            # Run DCT over all inputs
            dct_tokens: list[np.ndarray] = [  # each of shape [B, num_control_points, control_components]
                dct(a, axis=1, norm="ortho") for a in action_data
            ]
            # Quantize and find the token range
            rounded_tokens: list[np.ndarray] = [  # each of shape [B, num_control_points, control_components]
                np.around(tokens * scale) for tokens in dct_tokens
            ]
            max_token = int(np.max([tokens.max() for tokens in rounded_tokens]))
            min_token = int(np.min([tokens.min() for tokens in rounded_tokens]))
            min_vocab_size = max_token - min_token

            # Convert to char tokens
            np_chr = np.frompyfunc(chr, 1, 1)
            char_tokens = [  # each of shape [B], one string per batch element
                np_chr((tokens - min_token).astype(np.int64).reshape(tokens.shape[0], -1)).sum(-1)
                for tokens in rounded_tokens
            ]
            rounded_tokens = None  # free memory before BPE training
            token_iter = itertools.chain.from_iterable(iter(batch_tokens) for batch_tokens in char_tokens)
        else:
            raise NotImplementedError(action_data[0].shape)

        assert (
            min_vocab_size <= vocab_size
        ), f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
        if min_vocab_size + 100 > vocab_size:
            logging.warning(
                f"Initial alphabet size {min_vocab_size} is almost as large as the vocab "
                f"size {vocab_size}, consider increasing the vocab size"
            )

        # Train the BPE tokenizer
        bpe = ByteLevelBPETokenizer()
        # Set up the entire range of possible tokens as the initial alphabet
        alphabet = [chr(i) for i in range(max_token - min_token + 1)]
        trainer = BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=2,
            show_progress=True,
            special_tokens=[],
            initial_alphabet=alphabet,
            max_token_length=10000,
        )
        # Train the inner tokenizer (don't use ByteLevelBPETokenizer.train_from_iterator()
        # because it doesn't support custom alphabets)
        bpe._tokenizer.train_from_iterator(token_iter, trainer=trainer)  # noqa: SLF001
        return cls(
            PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
            scale=scale,
            vocab_size=vocab_size,
            min_token=min_token,
            time_horizon=time_horizon,
            action_dim=action_dim,
        )
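

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API: the shapes, scale, and vocab size
    # below are arbitrary illustrative choices.
    rng = np.random.default_rng(0)
    chunks = [rng.normal(size=(8, 50, 7)) for _ in range(4)]  # [batch, timesteps, action_dim]

    processor = FASTProcessor.fit(chunks, scale=10, vocab_size=1024)
    token_ids = processor(chunks[0])         # one list of BPE token ids per batch element
    recovered = processor.decode(token_ids)  # array of shape (8, 50, 7)
    print(recovered.shape, len(token_ids))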