| | |
| |
|
import logging
from typing import ClassVar, Iterable

import numpy as np
from scipy.fft import dct, idct
from tokenizers import ByteLevelBPETokenizer, pre_tokenizers
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast
from transformers.processing_utils import ProcessorMixin
| |
|
| |
|
class ResidualFASTActionProcessor(ProcessorMixin):
    """
    Residual FAST: intent + residual tokenization built on top of FAST's DCT+BPE scheme.

    Encodes an action chunk (B, T, D) into tokens by:
      1) DCT over the time axis
      2) Splitting the coefficients:
           intent   = coeff[:k_intent, :]
           residual = coeff[k_intent:, :]
      3) Quantizing to ints: round(coeff * scale)
      4) Converting to a string of characters via chr(int - min_token)
      5) Wrapping with special markers: <INTENT> ... <RESIDUAL> ...
      6) BPE-tokenizing the resulting string

    Decoding reverses the above.

    Notes:
      - Assumes input actions are already normalized to roughly [-1, 1] (same as FAST).
      - Uses a single BPE tokenizer to keep the interface identical to FAST.
      - Markers are special tokens so decode can reliably split streams.
      - Quantized values below ``min_token`` are clamped to ``min_token`` when
        encoding, so values below the range seen by ``fit`` are not reproduced exactly.
      - Because values are encoded with ``chr``, the decoded text can legitimately
        contain whitespace/control characters and must be parsed verbatim.
    """

    attributes: ClassVar[list[str]] = ["bpe_tokenizer"]
    bpe_tokenizer_class: str = "AutoTokenizer"

    INTENT_MARKER = "<INTENT>"
    RESIDUAL_MARKER = "<RESIDUAL>"

    def __init__(
        self,
        bpe_tokenizer: PreTrainedTokenizerFast,
        *,
        k_intent: int = 5,
        scale: float = 10.0,
        vocab_size: int = 1024,
        min_token: int = 0,
        action_dim: int | None = None,
        time_horizon: int | None = None,
    ):
        """
        Args:
            bpe_tokenizer: tokenizer for the BPE step; the intent/residual markers
                are registered on it as special tokens if not already present.
            k_intent: number of leading DCT coefficients (per action dim) that form
                the intent stream.
            scale: multiplier applied to DCT coefficients before rounding.
            vocab_size: BPE vocabulary size (stored; not used by encode/decode).
            min_token: offset subtracted from quantized values before ``chr``.
            action_dim: action dimension for ``decode``; if None, the dimension of
                the most recent call is used.
            time_horizon: chunk length for ``decode``; if None, the length of the
                most recent call is used.
        """
        self.k_intent = int(k_intent)
        self.scale = float(scale)
        self.vocab_size = int(vocab_size)
        self.min_token = int(min_token)

        # Configured shapes (may be None) and the shape of the most recent call,
        # which decode() falls back to.
        self.time_horizon = time_horizon
        self.action_dim = action_dim
        self.called_time_horizon = time_horizon
        self.called_action_dim = action_dim

        # Markers must exist as special tokens so decode can locate the streams.
        self._ensure_special_tokens(bpe_tokenizer)

        super().__init__(bpe_tokenizer)

    @staticmethod
    def _ensure_special_tokens(tok: PreTrainedTokenizerFast) -> None:
        """Register the intent/residual markers as special tokens on ``tok`` if absent."""
        special = set(tok.all_special_tokens)
        to_add = [
            marker
            for marker in (
                ResidualFASTActionProcessor.INTENT_MARKER,
                ResidualFASTActionProcessor.RESIDUAL_MARKER,
            )
            if marker not in special
        ]
        if to_add:
            tok.add_special_tokens({"additional_special_tokens": to_add})

    def __call__(self, action_chunk: np.ndarray) -> list[list[int]]:
        """
        Encode one action chunk or a batch of chunks into BPE token ids.

        Args:
            action_chunk: array of shape (T, D) or (B, T, D).

        Returns:
            A list of length B; each entry is the token-id list for one chunk.

        Raises:
            AssertionError: if ``action_chunk`` has more than 3 dims.
            ValueError: if it has fewer than 2 dims, or ``k_intent`` is not in [0, T].
        """
        action_chunk = np.asarray(action_chunk)
        assert action_chunk.ndim <= 3, "Only up to 3 dims supported: [batch, timesteps, action_dim]"
        if action_chunk.ndim < 2:
            raise ValueError(
                f"Expected action_chunk of shape (T, D) or (B, T, D), got {action_chunk.shape}"
            )
        if action_chunk.ndim == 2:
            action_chunk = action_chunk[None, ...]

        _, T, D = action_chunk.shape
        if self.k_intent < 0 or self.k_intent > T:
            raise ValueError(f"k_intent must be in [0, T]. Got k_intent={self.k_intent}, T={T}")

        # Remember the shape so decode() can be called without explicit T/D.
        self.called_time_horizon = T
        self.called_action_dim = D

        coeff = dct(action_chunk, axis=1, norm="ortho")
        # Quantize, then shift by min_token; negatives are clamped to 0 because
        # chr() cannot take negative values.
        shifted = np.maximum(np.around(coeff * self.scale).astype(int) - self.min_token, 0)

        tokens: list[list[int]] = []
        for sample in shifted:
            intent_chars = "".join(map(chr, sample[: self.k_intent].flatten()))
            residual_chars = "".join(map(chr, sample[self.k_intent :].flatten()))
            token_str = f"{self.INTENT_MARKER}{intent_chars}{self.RESIDUAL_MARKER}{residual_chars}"
            ids = self.bpe_tokenizer(token_str, add_special_tokens=False)["input_ids"]
            tokens.append(ids)

        return tokens

    def decode(
        self,
        tokens: list[list[int]],
        *,
        time_horizon: int | None = None,
        action_dim: int | None = None,
        k_intent: int | None = None,
    ) -> np.ndarray:
        """
        Decode token-id lists back into action chunks.

        Shape resolution (first truthy value wins): explicit argument, the instance's
        ``time_horizon``/``action_dim``, then the shape of the most recent call.

        Args:
            tokens: batch of token-id lists as produced by ``__call__``.
            time_horizon: chunk length T (optional, see above).
            action_dim: action dimension D (optional, see above).
            k_intent: intent length override; defaults to ``self.k_intent``.

        Returns:
            np.ndarray of shape (B, T, D). Sequences that fail to decode yield zeros.

        Raises:
            AssertionError: if T or D cannot be determined.
            ValueError: if ``k_intent`` is not in [0, T].
        """
        # NOTE(review): the resolved values are written back to self.time_horizon /
        # self.action_dim, so a later __call__ with a different T/D does not change
        # what decode() uses unless time_horizon/action_dim are passed explicitly.
        self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
        self.action_dim = action_dim or self.action_dim or self.called_action_dim
        K = int(k_intent) if k_intent is not None else self.k_intent

        self.called_time_horizon = self.time_horizon
        self.called_action_dim = self.action_dim

        assert self.time_horizon is not None and self.action_dim is not None, (
            "Tokenizer not initialized. Call encode() once or pass time_horizon and action_dim."
        )

        T = int(self.time_horizon)
        D = int(self.action_dim)
        if K < 0 or K > T:
            raise ValueError(f"k_intent must be in [0, T]. Got k_intent={K}, T={T}")

        decoded_actions = []
        for token_ids in tokens:
            try:
                action = self._decode_one(token_ids, T, D, K)
            except Exception as e:
                # Deliberate best-effort: a malformed sequence becomes a zero chunk.
                logging.warning("[ResidualFAST] Error decoding tokens: %s", e)
                logging.warning("[ResidualFAST] Tokens: %s", token_ids)
                action = np.zeros((T, D), dtype=float)
            decoded_actions.append(action)

        if not decoded_actions:
            # np.stack rejects an empty list; return an empty batch instead.
            return np.zeros((0, T, D), dtype=float)
        return np.stack(decoded_actions, axis=0)

    def _decode_one(self, token_ids: list[int], T: int, D: int, K: int) -> np.ndarray:
        """Decode one token-id list into a (T, D) action chunk; raise ValueError if malformed."""
        decoded = self.bpe_tokenizer.decode(
            token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )
        # Do NOT strip whitespace: coefficients are encoded with chr(), so e.g. the
        # value min_token + 32 is a literal space and removing it would lose data.
        # Cleanup is disabled for the same reason.

        error = "Missing or misordered <INTENT>/<RESIDUAL> markers in decoded string."
        i0 = decoded.find(self.INTENT_MARKER)
        if i0 == -1:
            raise ValueError(error)
        intent_begin = i0 + len(self.INTENT_MARKER)
        i1 = decoded.find(self.RESIDUAL_MARKER, intent_begin)
        if i1 == -1:
            raise ValueError(error)

        intent_str = decoded[intent_begin:i1]
        residual_str = decoded[i1 + len(self.RESIDUAL_MARKER) :]

        # Undo the chr(value - min_token) mapping.
        intent_vals = np.array([ord(c) for c in intent_str], dtype=int) + self.min_token
        residual_vals = np.array([ord(c) for c in residual_str], dtype=int) + self.min_token

        if intent_vals.size != K * D:
            raise ValueError(f"Intent size mismatch: got {intent_vals.size}, expected {K*D}")
        if residual_vals.size != (T - K) * D:
            raise ValueError(f"Residual size mismatch: got {residual_vals.size}, expected {(T-K)*D}")

        # Intent rows (first K) followed by residual rows, in row-major order.
        coeff_q = np.concatenate([intent_vals, residual_vals]).reshape(T, D).astype(float)
        return idct(coeff_q / self.scale, axis=0, norm="ortho")

    @classmethod
    def fit(
        cls,
        action_data: list[np.ndarray] | np.ndarray,
        *,
        k_intent: int = 5,
        scale: float = 10.0,
        vocab_size: int = 1024,
        time_horizon: int | None = None,
        action_dim: int | None = None,
    ) -> "ResidualFASTActionProcessor":
        """
        Train the internal BPE tokenizer on Residual FAST strings.

        action_data can be:
          - list of arrays, each (T, D)
          - or a single array (N, T, D)

        NOTE:
          - We keep the FAST alphabet trick: all possible quantized values are present
            in initial_alphabet, plus the full ByteLevel alphabet (see below).
          - We reserve room in vocab_size for the special marker tokens.

        Raises:
            AssertionError: on malformed input arrays or if vocab_size is too small.
            ValueError: on empty input, inconsistent action_dim, or k_intent not in [0, T].
        """
        if isinstance(action_data, np.ndarray):
            assert action_data.ndim == 3, "If passing np.ndarray, expected shape (N, T, D)."
        chunks = list(action_data)

        if len(chunks) == 0:
            raise ValueError("Empty action_data passed to fit().")

        for a in chunks:
            assert a.ndim == 2, "Each chunk must be (T, D)."

        Ds = {c.shape[1] for c in chunks}
        if len(Ds) != 1 and action_dim is None:
            raise ValueError("Varying action_dim in fit() data. Pass action_dim=... or standardize D.")
        D = action_dim if action_dim is not None else chunks[0].shape[1]

        # DCT + quantize every chunk once; reused for both range stats and training strings.
        quantized: list[np.ndarray] = []
        for a in chunks:
            T, d = a.shape
            if d != D:
                raise ValueError(f"Chunk action_dim={d} != expected D={D}.")
            if k_intent < 0 or k_intent > T:
                raise ValueError(f"k_intent must be in [0, T]. Got k_intent={k_intent}, T={T}")
            quantized.append(np.around(dct(a, axis=0, norm="ortho") * scale).astype(int))

        all_q = np.concatenate([q.ravel() for q in quantized], axis=0)
        max_token = int(all_q.max())
        min_token = int(all_q.min())

        token_range = max_token - min_token + 1
        n_special = 2
        required_vocab = token_range + n_special
        if required_vocab > vocab_size:
            raise AssertionError(
                f"vocab_size={vocab_size} too small. Need >= (range+special) = {required_vocab} "
                f"(range={max_token-min_token+1}, special={n_special})."
            )

        if token_range + 100 > vocab_size:
            logging.warning(
                "Initial alphabet size is close to vocab_size. Consider increasing vocab_size "
                "for better BPE compression."
            )

        def _token_iter() -> Iterable[str]:
            for q in quantized:
                shifted = q - min_token
                intent_str = "".join(map(chr, shifted[:k_intent].ravel()))
                residual_str = "".join(map(chr, shifted[k_intent:].ravel()))
                yield f"{cls.INTENT_MARKER}{intent_str}{cls.RESIDUAL_MARKER}{residual_str}"

        bpe = ByteLevelBPETokenizer()

        # The ByteLevel pre-tokenizer converts text into byte-level symbols, so those
        # symbols (not the raw chr() values) are what the BPE model sees. This model has
        # no unk token, so symbols missing from the vocab are dropped at encode time.
        # Seeding the full ByteLevel alphabet keeps encoding lossless for values that
        # never appeared in the training data.
        alphabet = [chr(i) for i in range(token_range)]
        initial_alphabet = list(dict.fromkeys(alphabet + pre_tokenizers.ByteLevel.alphabet()))

        trainer = BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=2,
            show_progress=True,
            special_tokens=[cls.INTENT_MARKER, cls.RESIDUAL_MARKER],
            initial_alphabet=initial_alphabet,
            max_token_length=10000,
        )

        # Train the inner tokenizer directly so the custom alphabet/trainer is used.
        bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)

        hf_tok = PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False)
        cls._ensure_special_tokens(hf_tok)

        return cls(
            hf_tok,
            k_intent=k_intent,
            scale=scale,
            vocab_size=vocab_size,
            min_token=min_token,
            time_horizon=time_horizon,
            action_dim=action_dim,
        )
| |
|