|
|
from typing import List, Tuple, Union |
|
|
|
|
|
import einops |
|
|
import numpy as np |
|
|
import torch |
|
|
from transformers import AutoModel, PreTrainedModel |
|
|
from vector_quantize_pytorch import VectorQuantize |
|
|
|
|
|
from .configuration_actioncodec import ActionCodecConfig |
|
|
from .modular_actioncodec import PerceiverDecoder, PerceiverEncoder |
|
|
from .rvq import ResidualVectorQuantize |
|
|
|
|
|
|
|
|
def trim_trailing_zeros(arr: np.ndarray) -> list[list]:
    """Trim trailing zeros from each row of a 2D array.

    Args:
        arr (np.ndarray): Input array of shape (b, n).

    Returns:
        list[list]: One Python list per row with trailing zeros removed.
            Rows that are entirely zero become empty lists.
    """
|
|
if arr.shape[0] == 0: |
|
|
return [] |
|
|
|
|
|
b, n = arr.shape |
|
|
|
|
|
is_nonzero = arr != 0 |
|
|
flipped_mask = np.flip(is_nonzero, axis=1) |
|
|
last_nonzero_indices = n - 1 - np.argmax(flipped_mask, axis=1) |
|
|
any_nonzero_in_row = is_nonzero.any(axis=1) |
|
|
new_lengths = (last_nonzero_indices + 1) * any_nonzero_in_row |
|
|
result = [arr[i, :length].tolist() for i, length in enumerate(new_lengths)] |
|
|
|
|
|
return result |
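
# Illustrative usage of `trim_trailing_zeros` (example values are made up for this sketch):
#
#   >>> trim_trailing_zeros(np.array([[1, 2, 0, 0], [0, 0, 0, 0]]))
#   [[1, 2], []]
#
# Rows are trimmed independently; a row that is entirely zero collapses to an empty list.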
|
|
|
|
|
|
|
|
class ActionCodec(PreTrainedModel): |
|
|
"""ActionCodec: A neural codec for encoding and decoding robot action sequences. |
|
|
|
|
|
This model uses a Perceiver-based encoder-decoder architecture with vector quantization |
|
|
to convert continuous action sequences into discrete token sequences. It supports |
|
|
multiple robot embodiments with different action dimensions and control frequencies. |
|
|
|
|
|
The model supports two vector quantization types: |
|
|
- VQ (Vector Quantization): Single quantizer |
|
|
- RVQ (Residual Vector Quantization): Multiple quantizers for hierarchical encoding |
|
|
|
|
|
Key features: |
|
|
- Multi-embodiment support: Handle different robots with varying action dimensions |
|
|
- Dynamic expansion: Add new robot configurations without retraining |
|
|
- Flexible input/output: Support numpy arrays and torch tensors |
|
|
""" |
|
|
|
|
|
config_class = ActionCodecConfig |
|
|
|
|
|
def __init__(self, config: ActionCodecConfig): |
|
|
"""Initialize the ActionCodec model. |
|
|
|
|
|
Args: |
|
|
config (ActionCodecConfig): Model configuration containing hyperparameters |
|
|
and embodiment configurations. |
|
|
|
|
|
Raises: |
|
|
ValueError: If configuration parameters are invalid. |
|
|
NotImplementedError: If the specified VQ type is not supported. |
|
|
""" |
|
|
super().__init__(config) |
|
|
|
|
|
|
|
|
if config.n_tokens % config.n_quantizers != 0: |
|
|
raise ValueError(f"n_tokens ({config.n_tokens}) must be divisible by n_quantizers ({config.n_quantizers})") |
|
|
|
|
|
if config.n_quantizers < 1: |
|
|
raise ValueError(f"n_quantizers must be at least 1, got {config.n_quantizers}") |
|
|
|
|
|
if config.vq_codebook_size < 1: |
|
|
raise ValueError(f"vq_codebook_size must be at least 1, got {config.vq_codebook_size}") |
|
|
|
|
|
if config.z_dim < 1: |
|
|
raise ValueError(f"z_dim must be at least 1, got {config.z_dim}") |
|
|
|
|
|
if not isinstance(config.embodiment_config, dict) or len(config.embodiment_config) == 0: |
|
|
raise ValueError( |
|
|
"embodiment_config must be a non-empty dictionary mapping embodiment names to configurations" |
|
|
) |
|
|
|
|
|
self.default_embodiment_id = 0 |
|
|
|
|
|
|
|
|
self.encoder = PerceiverEncoder(config) |
|
|
self.decoder = PerceiverDecoder(config) |
|
|
|
|
|
|
|
|
if config.vq_type == "vq": |
|
|
if config.n_quantizers != 1: |
|
|
raise ValueError( |
|
|
f"VQ type requires n_quantizers=1, got {config.n_quantizers}. Use RVQ type for multiple quantizers." |
|
|
) |
|
|
self.vq = VectorQuantize( |
|
|
dim=config.z_dim, |
|
|
codebook_size=config.vq_codebook_size, |
|
|
commitment_weight=config.vq_commitment_weight, |
|
|
decay=config.vq_decay, |
|
|
kmeans_init=config.vq_kmeans_init, |
|
|
threshold_ema_dead_code=config.vq_threshold_ema_dead_code, |
|
|
rotation_trick=False, |
|
|
straight_through=True, |
|
|
) |
|
|
elif config.vq_type == "rvq": |
|
|
if config.n_quantizers < 2: |
|
|
raise ValueError( |
|
|
f"RVQ type requires n_quantizers >= 2, got {config.n_quantizers}. Use VQ type for single quantizer." |
|
|
) |
|
|
self.vq = ResidualVectorQuantize( |
|
|
dim=config.z_dim, |
|
|
n_codebooks=config.n_quantizers, |
|
|
codebook_size=config.vq_codebook_size, |
|
|
codebook_dim=config.z_dim, |
|
|
quantizer_dropout=config.vq_quantizer_dropout, |
|
|
commitment=config.vq_commitment_weight, |
|
|
) |
|
|
else: |
|
|
raise NotImplementedError(f"VQ type '{config.vq_type}' not implemented. Supported types: 'vq', 'rvq'") |
|
|
|
|
|
|
|
|
self.vocab_size = config.vq_codebook_size |
|
|
self.num_quantizers = config.n_quantizers |
|
|
self.n_tokens_per_quantizer = config.n_tokens // config.n_quantizers |
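
        # Token bookkeeping, using the illustrative values from the __main__ demo below:
        # with n_tokens=16 and n_quantizers=4 (RVQ), each quantizer emits
        # n_tokens_per_quantizer = 16 // 4 = 4 codes, and encode() flattens them into a
        # single sequence of 16 integers, each in [0, vq_codebook_size).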
|
|
|
|
|
def expand_embodiment(self, embodiment_config: dict): |
|
|
"""Dynamically expand the model to support new robot embodiments. |
|
|
|
|
|
This method allows adding new robot configurations to the codec without retraining |
|
|
the entire model. It updates the encoder and decoder to handle the new action dimensions |
|
|
and frequencies while preserving existing functionality for previously configured robots. |
|
|
|
|
|
Args: |
|
|
embodiment_config (dict): Dictionary mapping embodiment names to their configurations. |
|
|
Each configuration should be a dict with keys: |
|
|
- "action_dim" (int): Action dimensionality for this embodiment. |
|
|
- "freq" (float): Control frequency in Hz. |
|
|
- "duration" (float): Default action sequence duration in seconds. |
|
|
- "description" (str, optional): Human-readable description. |
|
|
|
|
|
Example: |
|
|
{ |
|
|
"robot_B": { |
|
|
"action_dim": 10, |
|
|
"freq": 20, |
|
|
"duration": 1.0, |
|
|
"description": "10-dim robot at 20Hz" |
|
|
} |
|
|
} |
|
|
|
|
|
Returns: |
|
|
ActionCodec: Returns self for method chaining. |
|
|
|
|
|
Note: |
|
|
- New embodiment keys must not already exist in the current configuration. |
|
|
- The model will automatically update max_action_dim if the new embodiment |
|
|
has a larger action dimension. |
|
|
- Existing embodiments will continue to work with their original configurations. |
|
|
""" |
|
|
if not isinstance(embodiment_config, dict): |
|
|
raise TypeError(f"embodiment_config must be a dict, got {type(embodiment_config)}") |
|
|
if len(embodiment_config) == 0: |
|
|
raise ValueError("embodiment_config cannot be empty") |
|
|
|
|
|
|
|
|
overlapping_keys = set(embodiment_config.keys()) & set(self.config.embodiment_config.keys()) |
|
|
if overlapping_keys: |
|
|
raise ValueError(f"The following embodiment keys already exist and cannot be redefined: {overlapping_keys}") |
|
|
|
|
|
self.encoder.expand_embodiment(embodiment_config) |
|
|
self.decoder.expand_embodiment(embodiment_config) |
|
|
self.config.embodiment_config.update(embodiment_config) |
|
|
return self |
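
    # Example of dynamic expansion (hypothetical embodiment name and values; the __main__
    # block at the end of this file runs a similar expansion):
    #
    #   codec.expand_embodiment(
    #       {"robot_B": {"action_dim": 10, "freq": 20, "duration": 1.0,
    #                    "description": "10-dim robot at 20Hz"}}
    #   )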
|
|
|
|
|
def _encode( |
|
|
self, |
|
|
x: torch.Tensor, |
|
|
embodiment_ids: torch.Tensor | int | None = None, |
|
|
padding_mask: torch.Tensor | None = None, |
|
|
) -> torch.Tensor: |
|
|
"""Encode action sequences into latent representations. |
|
|
|
|
|
Args: |
|
|
x (torch.Tensor): Action sequences to encode. Shape: (b, seq_len, max_action_dim). |
|
|
Assumes that the action dimension is zero-padded to the max action dimension. |
|
|
                `seq_len` is expected to be `int(duration * freq)` for each embodiment, padded to the maximum sequence length in the batch.
|
|
            embodiment_ids (torch.Tensor | int | None, optional): Embodiment IDs specifying which
                embodiment each sequence belongs to. Shape: (b,) if tensor. If int, the same
                embodiment ID is used for all sequences in the batch.
                Defaults to None, which uses `self.default_embodiment_id`.
|
|
            padding_mask (torch.Tensor | None, optional): Padding mask where `False` marks
                padded timesteps along the `seq_len` dimension. Shape: (b, seq_len).
                Defaults to None.
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Encoded latent representations. Shape: (b, n_tokens_per_quantizer, z_dim). |
|
|
""" |
|
|
embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id |
|
|
z_e = self.encoder(x, embodiment_ids, padding_mask) |
|
|
return z_e |
|
|
|
|
|
def _quantize( |
|
|
self, z_e: torch.Tensor, return_perplexity: bool = True |
|
|
) -> Tuple[torch.Tensor, torch.Tensor, Union[float, List[float]], torch.Tensor]: |
|
|
"""Quantize encoded representations using vector quantization. |
|
|
|
|
|
Args: |
|
|
z_e (torch.Tensor): Encoded latent representations to quantize. |
|
|
Shape: (b, n_tokens_per_quantizer, z_dim). |
|
|
return_perplexity (bool, optional): Whether to compute and return perplexity. |
|
|
Defaults to True. |
|
|
|
|
|
Returns: |
|
|
Tuple[torch.Tensor, torch.Tensor, Union[float, List[float]], torch.Tensor]: |
|
|
A tuple containing: |
|
|
- z_q (torch.Tensor): Quantized representations. |
|
|
Shape: (b, n_tokens_per_quantizer, z_dim). |
|
|
- indices (torch.Tensor): Quantization indices. |
|
|
Shape: (b, n_tokens_per_quantizer) for VQ or (b, n_tokens_per_quantizer, n_quantizers) for RVQ. |
|
|
                - perplexity (Union[float, List[float]]): Codebook perplexity.
                    A list with one entry per quantizer when `return_perplexity` is True;
                    0 otherwise.
|
|
- commit_loss (torch.Tensor): Commitment loss scalar tensor. |
|
|
""" |
|
|
if isinstance(self.vq, ResidualVectorQuantize): |
|
|
z_q, indices, _, commitment_loss, codebook_loss = self.vq(z_e) |
|
|
commit_loss = commitment_loss.mean() + codebook_loss.mean() |
|
|
elif isinstance(self.vq, VectorQuantize): |
|
|
z_q, indices, commit_loss = self.vq(z_e) |
|
|
else: |
|
|
raise NotImplementedError(f"VQ type {type(self.vq)} not implemented") |
|
|
|
|
|
if return_perplexity: |
|
|
if len(indices.size()) < 3: |
|
|
indices = indices.unsqueeze(-1) |
|
|
perplexity = [] |
|
|
for k in range(indices.size(-1)): |
|
|
this_indices = indices[:, :, k] |
|
|
indices_count = torch.bincount(this_indices.view(-1), minlength=self.vq.codebook_size) |
|
|
if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: |
|
|
torch.distributed.all_reduce(indices_count) |
|
|
this_avg_probs = indices_count.float() / indices_count.sum() |
|
|
                perplexity.append(torch.exp(-(this_avg_probs * torch.log(this_avg_probs + 1e-10)).sum()).item())
|
|
else: |
|
|
perplexity = 0 |
|
|
|
|
|
return z_q, indices, perplexity, commit_loss |
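
    # Perplexity above is exp(entropy) of the empirical code distribution, i.e. the
    # "effective number of codebook entries in use". A toy example (not tied to any real
    # checkpoint): if a batch uses codes {0, 1, 2, 3} uniformly out of a 256-entry
    # codebook, the per-code probabilities are 0.25 for those four codes and ~0 elsewhere,
    # so perplexity = exp(-(4 * 0.25 * ln 0.25)) = exp(ln 4) = 4.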
|
|
|
|
|
def _dequantize(self, indices: torch.Tensor) -> torch.Tensor: |
|
|
"""Dequantize token indices back to continuous latent representations. |
|
|
|
|
|
Args: |
|
|
indices (torch.Tensor): Quantization indices. Shape depends on quantizer type: |
|
|
- For VQ: (b, n_tokens) or (b, n_tokens, 1) |
|
|
- For RVQ: (b, n_tokens_per_quantizer, n_quantizers) |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Dequantized latent representations. |
|
|
Shape: (b, n_tokens_per_quantizer, z_dim) |
|
|
""" |
|
|
if self.num_quantizers == 1: |
|
|
if len(indices.size()) == 3: |
|
|
indices = indices.squeeze(-1) |
|
|
if isinstance(self.vq, ResidualVectorQuantize): |
|
|
z_q = self.vq.from_codes(indices)[0] |
|
|
elif isinstance(self.vq, VectorQuantize): |
|
|
z_q = self.vq.get_output_from_indices(indices) |
|
|
else: |
|
|
raise NotImplementedError(f"VQ type {type(self.vq)} not implemented in _dequantize") |
|
|
return z_q |
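
    # Round-trip sketch (assuming a trained codebook): `_dequantize` looks the indices back
    # up in the codebook(s), reproducing (up to training-time quantizer dropout) the z_q
    # values `_quantize` emitted for those indices, which is what lets decode() reconstruct
    # actions from tokens alone.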
|
|
|
|
|
def _decode( |
|
|
self, z_q: torch.Tensor, embodiment_ids: torch.Tensor | int | None = None, durations: torch.Tensor | None = None |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
"""Decode quantized latent representations into action sequences. |
|
|
|
|
|
Args: |
|
|
z_q (torch.Tensor): Quantized latent representations. |
|
|
Shape: (b, n_tokens_per_quantizer, z_dim). |
|
|
embodiment_ids (Union[torch.Tensor, int, None], optional): Embodiment IDs. |
|
|
Shape: (b,) if tensor. If int, the same embodiment ID is used for all |
|
|
sequences. Defaults to None, which uses `self.default_embodiment_id`. |
|
|
durations (torch.Tensor | None, optional): Duration of each action sequence in seconds. |
|
|
Shape: (b,). If None, uses default duration from embodiment_config. |
|
|
Defaults to None. |
|
|
|
|
|
Returns: |
|
|
Tuple[torch.Tensor, torch.Tensor]: A tuple containing: |
|
|
- x_recon (torch.Tensor): Reconstructed action sequences. |
|
|
Shape: (b, seq_len, max_action_dim). |
|
|
- padding_mask (torch.Tensor): Padding mask indicating valid timesteps. |
|
|
Shape: (b, seq_len), where True indicates valid timesteps. |
|
|
""" |
|
|
embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id |
|
|
x_recon, padding_mask = self.decoder(z_q, embodiment_ids, durations) |
|
|
return x_recon, padding_mask |
|
|
|
|
|
@torch.no_grad() |
|
|
def encode( |
|
|
self, |
|
|
x: Union[np.ndarray, torch.Tensor], |
|
|
embodiment_ids: Union[List[int], int, None] = None, |
|
|
padding_mask: Union[List[bool], np.ndarray, torch.Tensor, None] = None, |
|
|
**kwargs, |
|
|
) -> List[List[int]]: |
|
|
"""Encode action sequences into latent representations (token indices). |
|
|
|
|
|
This method converts action sequences into discrete token indices using the encoder |
|
|
and vector quantizer. The input can be either a numpy array or torch tensor. |
|
|
|
|
|
Args: |
|
|
x (Union[np.ndarray, torch.Tensor]): Action sequences to encode. |
|
|
Shape: (b, seq_len, max_action_dim). |
|
|
Assumes that the action dimension is zero-padded to the max action dimension. |
|
|
`seq_len` is supposed to be `int(duration * freq)` for each embodiment and |
|
|
padded to the max sequence length. |
|
|
embodiment_ids (Union[List[int], int, None], optional): Embodiment IDs. |
|
|
Shape: (b,) if list. If int, the same embodiment ID is repeated for all |
|
|
sequences in the batch. It specifies the embodiment to encode. |
|
|
Defaults to None, which uses `self.default_embodiment_id`. |
|
|
padding_mask (Union[List[bool], np.ndarray, torch.Tensor, None], optional): |
|
|
Padding mask, where `False` values indicate padding. Shape: (b, seq_len). |
|
|
Defaults to None. It is used to mask the padding tokens on `seq_len` dimension. |
|
|
**kwargs: Additional keyword arguments (currently unused, reserved for future use). |
|
|
|
|
|
Returns: |
|
|
            List[List[int]]: List of token sequences. Shape: (b, n_tokens), where n_tokens
                equals `config.n_tokens` (i.e. n_quantizers * n_tokens_per_quantizer).
|
|
|
|
|
Raises: |
|
|
ValueError: If input shapes are invalid or incompatible with the model configuration. |
|
|
TypeError: If input types are not supported. |
|
|
|
|
|
Examples: |
|
|
>>> import numpy as np |
|
|
>>> # Using numpy array |
|
|
>>> x = np.random.randn(2, 10, 7).astype(np.float32) |
|
|
>>> tokens = model.encode(x, embodiment_ids=[0, 0]) |
|
|
>>> # Using torch tensor |
|
|
>>> x_tensor = torch.randn(2, 10, 7) |
|
|
>>> tokens = model.encode(x_tensor, embodiment_ids=[0, 0]) |
|
|
""" |
|
|
self.eval() |
|
|
|
|
|
|
|
|
if isinstance(x, np.ndarray): |
|
|
if x.ndim != 3: |
|
|
raise ValueError( |
|
|
f"Expected 3D input array (batch, seq_len, action_dim), got {x.ndim}D array with shape {x.shape}" |
|
|
) |
|
|
x_tensor = torch.tensor(x, dtype=self.dtype, device=self.device) |
|
|
elif isinstance(x, torch.Tensor): |
|
|
if x.ndim != 3: |
|
|
raise ValueError( |
|
|
f"Expected 3D tensor (batch, seq_len, action_dim), got {x.ndim}D tensor with shape {x.shape}" |
|
|
) |
|
|
x_tensor = x.to(dtype=self.dtype, device=self.device) |
|
|
else: |
|
|
raise TypeError(f"Input x must be numpy.ndarray or torch.Tensor, got {type(x)}") |
|
|
|
|
|
|
|
|
batch_size = x_tensor.shape[0] |
|
|
if batch_size == 0: |
|
|
raise ValueError("Batch size must be at least 1") |
|
|
|
|
|
|
|
|
embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id |
|
|
if isinstance(embodiment_ids, int): |
|
|
if not 0 <= embodiment_ids < len(self.config.embodiment_config): |
|
|
raise ValueError( |
|
|
f"embodiment_id {embodiment_ids} is out of range [0, {len(self.config.embodiment_config)}). " |
|
|
f"Available embodiment IDs: {list(range(len(self.config.embodiment_config)))}" |
|
|
) |
|
|
embodiment_ids_tensor = torch.tensor([embodiment_ids] * batch_size, dtype=torch.long, device=self.device) |
|
|
elif isinstance(embodiment_ids, list): |
|
|
if len(embodiment_ids) != batch_size: |
|
|
raise ValueError( |
|
|
f"Length of embodiment_ids ({len(embodiment_ids)}) must match batch size ({batch_size})" |
|
|
) |
|
|
for eid in embodiment_ids: |
|
|
if not isinstance(eid, int) or not 0 <= eid < len(self.config.embodiment_config): |
|
|
raise ValueError( |
|
|
f"Invalid embodiment_id {eid}. Must be an integer in range [0, {len(self.config.embodiment_config)})" |
|
|
) |
|
|
embodiment_ids_tensor = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device) |
|
|
else: |
|
|
raise TypeError(f"embodiment_ids must be int, List[int], or None, got {type(embodiment_ids)}") |
|
|
|
|
|
|
|
|
padding_mask_tensor = None |
|
|
if padding_mask is not None: |
|
|
if isinstance(padding_mask, (list, np.ndarray)): |
|
|
padding_mask_tensor = torch.tensor(padding_mask, dtype=torch.bool, device=self.device) |
|
|
elif isinstance(padding_mask, torch.Tensor): |
|
|
padding_mask_tensor = padding_mask.to(dtype=torch.bool, device=self.device) |
|
|
else: |
|
|
raise TypeError( |
|
|
f"padding_mask must be List[bool], np.ndarray, torch.Tensor, or None, got {type(padding_mask)}" |
|
|
) |
|
|
if padding_mask_tensor.shape != (batch_size, x_tensor.shape[1]): |
|
|
raise ValueError( |
|
|
f"padding_mask shape {padding_mask_tensor.shape} does not match expected shape " |
|
|
f"({batch_size}, {x_tensor.shape[1]})" |
|
|
) |
|
|
|
|
|
with torch.no_grad(): |
|
|
z_e = self._encode(x_tensor, embodiment_ids_tensor, padding_mask_tensor) |
|
|
_, indices, _, _ = self._quantize(z_e, return_perplexity=False) |
|
|
|
|
|
|
|
|
if len(indices.size()) > 2: |
|
|
codes_list = einops.rearrange(indices, "b n s -> b (s n)").cpu() |
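            # Assuming RVQ indices of shape (b, T, Q) as documented in `_quantize`, this
            # flattens them quantizer-major, so the token list reads
            # [q0_t0, ..., q0_t{T-1}, q1_t0, ...]; decode() undoes exactly this layout.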
|
|
else: |
|
|
codes_list = indices.cpu() |
|
|
|
|
|
codes_list = codes_list.tolist() |
|
|
return codes_list |
|
|
|
|
|
@torch.no_grad() |
|
|
def decode( |
|
|
self, |
|
|
tokens: Union[List[List[int]], np.ndarray, torch.Tensor], |
|
|
embodiment_ids: Union[List[int], int, None] = None, |
|
|
durations: Union[List[float], np.ndarray, torch.Tensor, None] = None, |
|
|
**kwargs, |
|
|
) -> Tuple[np.ndarray, np.ndarray]: |
|
|
"""Decode token sequences into action sequences. |
|
|
|
|
|
This method reconstructs action sequences from discrete token indices using the |
|
|
vector quantizer and decoder. The input tokens can be a list of lists, numpy array, |
|
|
or torch tensor. |
|
|
|
|
|
Args: |
|
|
tokens (Union[List[List[int]], np.ndarray, torch.Tensor]): Token sequences to decode. |
|
|
Shape: (b, n_tokens), where n_tokens must be divisible by `n_tokens_per_quantizer`. |
|
|
                For RVQ, tokens are grouped by quantizer: [q0_t0, q0_t1, ..., q0_t{T-1}, q1_t0, ...],
                where T = n_tokens_per_quantizer, matching the layout produced by `encode()`.
|
|
embodiment_ids (Union[List[int], int, None], optional): Embodiment IDs. |
|
|
Shape: (b,) if list. If int, the same embodiment ID is repeated for all |
|
|
sequences in the batch. It specifies the embodiment to decode. |
|
|
Defaults to None, which uses `self.default_embodiment_id`. |
|
|
durations (Union[List[float], np.ndarray, torch.Tensor, None], optional): |
|
|
Duration of each action sequence in seconds. Shape: (b,). |
|
|
If None, the duration is inferred from the default values in `embodiment_config`. |
|
|
Defaults to None. |
|
|
**kwargs: Additional keyword arguments (currently unused, reserved for future use). |
|
|
|
|
|
Returns: |
|
|
Tuple[np.ndarray, np.ndarray]: A tuple containing: |
|
|
- reconstructed_actions: Reconstructed action sequences. |
|
|
Shape: (b, seq_len, max_action_dim). |
|
|
- padding_mask: Padding mask indicating valid timesteps. |
|
|
Shape: (b, seq_len), where True indicates valid timesteps. |
|
|
|
|
|
Raises: |
|
|
ValueError: If token sequence length is invalid or incompatible with the model configuration. |
|
|
TypeError: If input types are not supported. |
|
|
|
|
|
Examples: |
|
|
>>> # Using list of lists |
|
|
>>> tokens = [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16]] |
|
|
>>> actions, mask = model.decode(tokens, embodiment_ids=[0, 0]) |
|
|
>>> # Using numpy array |
|
|
>>> tokens_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]) |
|
|
>>> actions, mask = model.decode(tokens_np, embodiment_ids=[0, 0]) |
|
|
>>> # Using torch tensor |
|
|
>>> tokens_tensor = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) |
|
|
>>> actions, mask = model.decode(tokens_tensor, embodiment_ids=[0, 0]) |
|
|
""" |
|
|
self.eval() |
|
|
|
|
|
|
|
|
if isinstance(tokens, list): |
|
|
if not all(isinstance(seq, list) for seq in tokens): |
|
|
raise TypeError("If tokens is a list, all elements must be lists") |
|
|
if len(tokens) == 0: |
|
|
raise ValueError("Tokens list cannot be empty") |
|
|
if not all(isinstance(val, (int, np.integer)) for seq in tokens for val in seq): |
|
|
raise TypeError("All token values must be integers") |
|
|
tokens_tensor = torch.tensor(tokens, dtype=torch.long, device=self.device) |
|
|
elif isinstance(tokens, np.ndarray): |
|
|
if tokens.ndim != 2: |
|
|
raise ValueError( |
|
|
f"Expected 2D array (batch, n_tokens), got {tokens.ndim}D array with shape {tokens.shape}" |
|
|
) |
|
|
if not np.issubdtype(tokens.dtype, np.integer): |
|
|
raise TypeError(f"Tokens array must have integer dtype, got {tokens.dtype}") |
|
|
tokens_tensor = torch.tensor(tokens, dtype=torch.long, device=self.device) |
|
|
elif isinstance(tokens, torch.Tensor): |
|
|
if tokens.ndim != 2: |
|
|
raise ValueError( |
|
|
f"Expected 2D tensor (batch, n_tokens), got {tokens.ndim}D tensor with shape {tokens.shape}" |
|
|
) |
|
|
            if tokens.is_floating_point() or tokens.is_complex() or tokens.dtype == torch.bool:
                raise TypeError(f"Tokens tensor must have integer dtype, got {tokens.dtype}")
|
|
tokens_tensor = tokens.to(dtype=torch.long, device=self.device) |
|
|
else: |
|
|
raise TypeError(f"tokens must be List[List[int]], np.ndarray, or torch.Tensor, got {type(tokens)}") |
|
|
|
|
|
batch_size, n_tokens = tokens_tensor.shape |
|
|
if batch_size == 0: |
|
|
raise ValueError("Batch size must be at least 1") |
|
|
if n_tokens == 0: |
|
|
raise ValueError("Token sequence length must be at least 1") |
|
|
|
|
|
|
|
|
if n_tokens % self.n_tokens_per_quantizer != 0: |
|
|
raise ValueError( |
|
|
f"Token sequence length ({n_tokens}) must be divisible by tokens per quantizer " |
|
|
f"({self.n_tokens_per_quantizer}). Total tokens: {n_tokens}, " |
|
|
f"Expected multiple of: {self.n_tokens_per_quantizer}. " |
|
|
f"Number of quantizers: {self.num_quantizers}, Total tokens per sequence: {self.config.n_tokens}" |
|
|
) |
|
|
|
|
|
|
|
|
if tokens_tensor.min() < 0 or tokens_tensor.max() >= self.vocab_size: |
|
|
raise ValueError( |
|
|
f"Token values must be in range [0, {self.vocab_size}), " |
|
|
f"got range [{tokens_tensor.min().item()}, {tokens_tensor.max().item()}]" |
|
|
) |
|
|
|
|
|
|
|
|
embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id |
|
|
if isinstance(embodiment_ids, int): |
|
|
if not 0 <= embodiment_ids < len(self.config.embodiment_config): |
|
|
raise ValueError( |
|
|
f"embodiment_id {embodiment_ids} is out of range [0, {len(self.config.embodiment_config)}). " |
|
|
f"Available embodiment IDs: {list(range(len(self.config.embodiment_config)))}" |
|
|
) |
|
|
embodiment_ids_tensor = torch.tensor([embodiment_ids] * batch_size, dtype=torch.long, device=self.device) |
|
|
elif isinstance(embodiment_ids, list): |
|
|
if len(embodiment_ids) != batch_size: |
|
|
raise ValueError( |
|
|
f"Length of embodiment_ids ({len(embodiment_ids)}) must match batch size ({batch_size})" |
|
|
) |
|
|
for eid in embodiment_ids: |
|
|
if not isinstance(eid, int) or not 0 <= eid < len(self.config.embodiment_config): |
|
|
raise ValueError( |
|
|
f"Invalid embodiment_id {eid}. Must be an integer in range [0, {len(self.config.embodiment_config)})" |
|
|
) |
|
|
embodiment_ids_tensor = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device) |
|
|
else: |
|
|
raise TypeError(f"embodiment_ids must be int, List[int], or None, got {type(embodiment_ids)}") |
|
|
|
|
|
|
|
|
durations_tensor = None |
|
|
if durations is not None: |
|
|
if isinstance(durations, (list, np.ndarray)): |
|
|
durations_tensor = torch.tensor(durations, dtype=torch.float32, device=self.device) |
|
|
elif isinstance(durations, torch.Tensor): |
|
|
durations_tensor = durations.to(dtype=torch.float32, device=self.device) |
|
|
else: |
|
|
raise TypeError( |
|
|
f"durations must be List[float], np.ndarray, torch.Tensor, or None, got {type(durations)}" |
|
|
) |
|
|
if durations_tensor.ndim != 1: |
|
|
raise ValueError( |
|
|
f"durations must be 1D, got {durations_tensor.ndim}D with shape {durations_tensor.shape}" |
|
|
) |
|
|
if len(durations_tensor) != batch_size: |
|
|
raise ValueError(f"Length of durations ({len(durations_tensor)}) must match batch size ({batch_size})") |
|
|
if (durations_tensor <= 0).any(): |
|
|
raise ValueError("All durations must be positive") |
|
|
|
|
|
|
|
|
indices = einops.rearrange(tokens_tensor, "b (n m) -> b m n", m=self.n_tokens_per_quantizer) |
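        # Inverse of the flattening done in encode(): token position q * T + t maps back to
        # indices[:, t, q], with T = n_tokens_per_quantizer. E.g. with T = 4 and 4 quantizers,
        # tokens[0:4] are quantizer 0's codes for timesteps 0..3.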
|
|
|
|
|
with torch.no_grad(): |
|
|
z_q = self._dequantize(indices) |
|
|
x_recon, padding_mask = self._decode(z_q, embodiment_ids_tensor, durations_tensor) |
|
|
|
|
|
return x_recon.float().cpu().numpy(), padding_mask.float().cpu().numpy() |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
x: Union[torch.Tensor, np.ndarray], |
|
|
embodiment_ids: Union[torch.Tensor, int, List[int], None] = None, |
|
|
padding_mask: Union[torch.Tensor, List[bool], np.ndarray, None] = None, |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
"""Forward pass through the full ActionCodec pipeline. |
|
|
|
|
|
This method performs encoding, quantization, and decoding in a single forward pass. |
|
|
        It is primarily used during training to produce reconstructions for a reconstruction loss;
        the quantizer's commitment loss and perplexity are computed internally but not returned.
|
|
Both numpy arrays and torch tensors are supported as input. |
|
|
|
|
|
Args: |
|
|
x (Union[torch.Tensor, np.ndarray]): Action sequences to process. |
|
|
Shape: (b, seq_len, max_action_dim). |
|
|
embodiment_ids (Union[torch.Tensor, int, List[int], None], optional): |
|
|
Embodiment IDs. Shape: (b,) if tensor or list. If int, same ID for all sequences. |
|
|
Defaults to None, which uses `self.default_embodiment_id`. |
|
|
padding_mask (Union[torch.Tensor, List[bool], np.ndarray, None], optional): |
|
|
Padding mask. Shape: (b, seq_len). Defaults to None. |
|
|
|
|
|
Returns: |
|
|
Tuple[torch.Tensor, torch.Tensor]: A tuple containing: |
|
|
- x_recon (torch.Tensor): Reconstructed action sequences. |
|
|
Shape: (b, seq_len, max_action_dim). |
|
|
- recon_mask (torch.Tensor): Reconstruction mask indicating valid timesteps. |
|
|
Shape: (b, seq_len), where True indicates valid timesteps. |
|
|
|
|
|
Note: |
|
|
- For inference use cases, prefer using `encode()` and `decode()` methods separately. |
|
|
- If you need token indices, use the `encode()` method instead. |
|
|
""" |
|
|
|
|
|
if isinstance(x, np.ndarray): |
|
|
x = torch.tensor(x, dtype=self.dtype, device=self.device) |
|
|
|
|
|
|
|
|
if isinstance(embodiment_ids, list): |
|
|
embodiment_ids = torch.tensor(embodiment_ids, device=x.device, dtype=torch.long) |
|
|
        elif isinstance(embodiment_ids, int):
            # A single int is broadcast to the whole batch by the encoder/decoder.
            pass
|
|
|
|
|
|
|
|
if isinstance(padding_mask, (list, np.ndarray)): |
|
|
padding_mask = torch.tensor(padding_mask, device=x.device, dtype=torch.bool) |
|
|
|
|
|
|
|
|
z_e = self._encode(x, embodiment_ids, padding_mask) |
|
|
        z_q, indices, perplexity, commit_loss = self._quantize(z_e, return_perplexity=True)
        # NOTE: indices, perplexity, and commit_loss are computed here but not returned by forward().
|
|
x_recon, recon_mask = self._decode(z_q, embodiment_ids) |
|
|
|
|
|
return x_recon, recon_mask |
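
    # A minimal training-step sketch (hypothetical; this file defines no training loop, and
    # the masked-MSE objective below is an assumption, not the authors' loss):
    #
    #   x_recon, recon_mask = codec(x, embodiment_ids, padding_mask)
    #   mask = recon_mask.unsqueeze(-1).float()
    #   recon_loss = ((x_recon - x) ** 2 * mask).sum() / mask.sum().clamp(min=1)
    #   recon_loss.backward()
    #
    # A commitment / codebook term would additionally require forward() to expose commit_loss.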
|
|
|
|
|
|
|
|
AutoModel.register(ActionCodecConfig, ActionCodec) |
|
|
|
|
|
__all__ = ["ActionCodec"] |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
print("=== ActionCodec Comprehensive Test ===\n") |
|
|
|
|
|
|
|
|
initial_config = { |
|
|
"robot_A": {"action_dim": 7, "freq": 10, "duration": 1, "description": "Robot A"}, |
|
|
} |
|
|
|
|
|
|
|
|
config = ActionCodecConfig( |
|
|
embodiment_config=initial_config, |
|
|
n_tokens=16, |
|
|
n_quantizers=4, |
|
|
vq_type="rvq", |
|
|
vq_codebook_size=256, |
|
|
encoder_dim=128, |
|
|
decoder_dim=128, |
|
|
) |
|
|
|
|
|
|
|
|
latent_seq_len = int(config.n_tokens // config.n_quantizers) |
|
|
print(f"Config: {config.n_quantizers} quantizers, {latent_seq_len} latent vectors per sequence.") |
|
|
|
|
|
codec = ActionCodec(config) |
|
|
codec.eval() |
|
|
|
|
|
|
|
|
print("\n--- Test 1: Basic Encode/Decode ---") |
|
|
batch_size = 2 |
|
|
seq_len_A = 10 |
|
|
|
|
|
|
|
|
x = np.random.randn(batch_size, seq_len_A, 7).astype(np.float32) |
|
|
|
|
|
padding_mask = np.ones((batch_size, seq_len_A), dtype=bool) |
|
|
padding_mask[1, 5:] = False |
|
|
|
|
|
embodiment_ids = [0, 0] |
|
|
|
|
|
|
|
|
codes = codec.encode(x, embodiment_ids, padding_mask) |
|
|
print(f"Encoded codes shape (list length): {len(codes)} x {len(codes[0])}") |
|
|
|
|
|
|
|
|
assert len(codes[0]) == config.n_tokens, f"Expected {config.n_tokens} tokens, got {len(codes[0])}" |
|
|
|
|
|
|
|
|
x_recon, recon_mask = codec.decode(codes, embodiment_ids) |
|
|
print(f"Reconstructed shape: {x_recon.shape}") |
|
|
print(f"Recon mask shape: {recon_mask.shape}") |
|
|
|
|
|
assert x_recon.shape == (batch_size, seq_len_A, 7) |
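
    # Optional sanity metric (illustrative only; an untrained codec's reconstruction error is arbitrary).
    recon_err = np.abs(x_recon[0] - x[0]).mean()
    print(f"Mean |recon - input| for sample 0: {recon_err:.4f}")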
|
|
|
|
|
|
|
|
print("\n--- Test 2: Dynamic Expansion ---") |
|
|
new_robot_config = {"robot_B": {"action_dim": 10, "freq": 20, "duration": 1, "description": "Robot B (Larger)"}} |
|
|
|
|
|
print("Expanding codec to include Robot B (10 dims, 20Hz)...") |
|
|
codec.expand_embodiment(new_robot_config) |
|
|
|
|
|
assert codec.encoder.max_action_dim == 10 |
|
|
assert codec.decoder.max_action_dim == 10 |
|
|
print("✅ Expansion successful.") |
|
|
|
|
|
|
|
|
print("\n--- Test 3: Mixed Batch Inference ---") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
batch_x_mixed = np.zeros((2, 20, 10), dtype=np.float32) |
|
|
|
|
|
|
|
|
data_A = np.random.randn(10, 7) |
|
|
batch_x_mixed[0, :10, :7] = data_A |
|
|
|
|
|
|
|
|
data_B = np.random.randn(20, 10) |
|
|
batch_x_mixed[1, :20, :10] = data_B |
|
|
|
|
|
|
|
|
|
|
|
mixed_ids = [0, 1] |
|
|
|
|
|
|
|
|
mixed_mask = np.zeros((2, 20), dtype=bool) |
|
|
mixed_mask[0, :10] = True |
|
|
mixed_mask[1, :20] = True |
|
|
|
|
|
print("Encoding mixed batch...") |
|
|
mixed_codes = codec.encode(batch_x_mixed, mixed_ids, mixed_mask) |
|
|
|
|
|
print("Decoding mixed batch...") |
|
|
|
|
|
durations = [1, 1] |
|
|
x_recon_mixed, dec_mask_mixed = codec.decode(mixed_codes, mixed_ids, durations) |
|
|
|
|
|
print(f"Mixed Recon Shape: {x_recon_mixed.shape}") |
|
|
|
|
|
|
|
|
|
|
|
valid_A = dec_mask_mixed[0].sum() |
|
|
valid_B = dec_mask_mixed[1].sum() |
|
|
|
|
|
print(f"Valid steps detected by Decoder: Robot A={valid_A}, Robot B={valid_B}") |
|
|
|
|
|
assert valid_A == 10 |
|
|
assert valid_B == 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("✅ Mixed batch processed successfully.") |
|
|
|
|
|
print("\n✨ All systems go.") |
|
|
|