Instructions to use ZibinDong/ActionCodec-Base-RVQft with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ZibinDong/ActionCodec-Base-RVQft with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="ZibinDong/ActionCodec-Base-RVQft", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("ZibinDong/ActionCodec-Base-RVQft", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import math | |
| from copy import deepcopy | |
| from typing import List, Literal, Optional, Tuple, Union | |
| import einops | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .configuration_actioncodec import ActionCodecConfig | |
def apply_rotary_pos_emb(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved channel pairs of ``x`` by the angles encoded in ``sin`` / ``cos``.

    Channels ``(0, 1), (2, 3), ...`` form pairs ``(a, b)`` that are mapped to
    ``(a * cos - b * sin, a * sin + b * cos)``. The arithmetic is done in float32 and the
    result is cast back to the dtype ``x`` arrived in.

    Args:
        x: Input tensor whose last dimension holds the interleaved pairs.
        sin: Sine of the rotation angles; must broadcast against ``x[..., 0::2]``.
        cos: Cosine of the rotation angles; must broadcast against ``x[..., 0::2]``.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    in_dtype = x.dtype
    xf = x.float()
    s, c = sin.float(), cos.float()
    even, odd = xf[..., 0::2], xf[..., 1::2]
    result = torch.empty_like(xf)
    result[..., 0::2] = even * c - odd * s
    result[..., 1::2] = even * s + odd * c
    return result.to(in_dtype)
| def attention_op( | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| v: torch.Tensor, | |
| mask: torch.Tensor | None = None, | |
| is_causal: bool = False, | |
| ) -> torch.Tensor: | |
| """ | |
| Args: | |
| q (torch.Tensor): (*b, h, l, d) | |
| k (torch.Tensor): (*b, k, s, d) | |
| v (torch.Tensor): (*b, k, s, d) | |
| mask (torch.Tensor | None, optional): (*b, l, s), where `True` indicates the element should take part in attention. Defaults to None. | |
| is_causal (bool, optional): Whether to apply causal mask. Defaults to False. | |
| Returns: | |
| torch.Tensor: (*b, h, l, d) | |
| """ | |
| heads, kv_heads = q.shape[-3], k.shape[-3] | |
| if heads != kv_heads: | |
| assert heads % kv_heads == 0, f"q_heads must be divisible by kv_heads, but got {heads} and {kv_heads}" | |
| heads_per_kv_head = heads // kv_heads | |
| k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v)) | |
| if mask is not None: | |
| if mask.dim() == 3: | |
| mask = mask.unsqueeze(1) | |
| mask = mask.expand(mask.shape[0], heads, -1, -1) | |
| out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=is_causal) | |
| return out | |
class L2Norm(nn.Module):
    """Parameter-free module that rescales every vector along the last axis to unit L2 length."""

    def forward(self, x: torch.Tensor):
        # Same computation as F.normalize(p=2, dim=-1): divide by max(||x||, 1e-12), so
        # all-zero vectors stay zero instead of turning into NaN.
        length = x.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-12).expand_as(x)
        return x / length
| class Attention(nn.Module): | |
| """ | |
| Args: | |
| hidden_size (int): Hidden size of the input tensor. | |
| num_heads (int): Number of attention heads. | |
| num_kv_heads (int, optional): Number of key/value heads. Defaults to None. | |
| qk_norm (Literal["l2", "ln", "none"], optional): Type of normalization to apply to query/key. Defaults to "none". | |
| bias (bool, optional): Whether to use bias in linear layers. Defaults to False. | |
| """ | |
| def __init__( | |
| self, | |
| hidden_size: int, | |
| num_heads: int, | |
| num_kv_heads: int | None = None, | |
| qk_norm: Literal["l2", "ln", "none"] = "none", | |
| bias: bool = False, | |
| zero_init_output: bool = False, | |
| ): | |
| super().__init__() | |
| num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads | |
| self.dim = hidden_size // num_heads | |
| self.num_heads, self.num_kv_heads = num_heads, num_kv_heads | |
| self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias) | |
| self.k_proj = nn.Linear(hidden_size, self.dim * num_kv_heads, bias=bias) | |
| self.v_proj = nn.Linear(hidden_size, self.dim * num_kv_heads, bias=bias) | |
| self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias) | |
| if qk_norm == "l2": | |
| self.q_norm = L2Norm() | |
| self.k_norm = L2Norm() | |
| elif qk_norm == "ln": | |
| self.q_norm = nn.LayerNorm(self.dim, elementwise_affine=False) | |
| self.k_norm = nn.LayerNorm(self.dim, elementwise_affine=False) | |
| else: | |
| self.q_norm = nn.Identity() | |
| self.k_norm = nn.Identity() | |
| if zero_init_output: | |
| nn.init.zeros_(self.out_proj.weight) | |
| if self.out_proj.bias is not None: | |
| nn.init.zeros_(self.out_proj.bias) | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| context: torch.Tensor | None = None, | |
| mask: torch.Tensor | None = None, | |
| rotary_pos_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, | |
| is_causal: bool = False, | |
| ) -> torch.Tensor: | |
| context = x if context is None else context | |
| q = self.q_proj(x) | |
| k, v = self.k_proj(context), self.v_proj(context) | |
| q = einops.rearrange(q, "b l (h d) -> b h l d", h=self.num_heads) | |
| k = einops.rearrange(k, "b s (h d) -> b h s d", h=self.num_kv_heads) | |
| v = einops.rearrange(v, "b s (h d) -> b h s d", h=self.num_kv_heads) | |
| q, k = self.q_norm(q), self.k_norm(k) | |
| if rotary_pos_emb is not None: | |
| q, k = map(lambda t: apply_rotary_pos_emb(t, *rotary_pos_emb), (q, k)) | |
| out = attention_op(q, k, v, mask=mask, is_causal=is_causal) | |
| out = einops.rearrange(out, "b h l d -> b l (h d)") | |
| out = self.out_proj(out) | |
| return out | |
class PositionalEmbedding(nn.Module):
    """Additive positional embedding: fixed sin/cos by index, or random Fourier features by timestamp.

    Args:
        dim (int): Embedding size (last dimension of the returned tensor).
        encoding_type (Literal["sincos", "fourier"], optional): "sincos" uses the Transformer
            sinusoidal table indexed by position. "fourier" maps timestamps ``position / freq``
            through random frequencies stored in the persistent ``freqs`` buffer
            (``torch.randn(dim // 2) * scale``). Defaults to "sincos".
        scale (float, optional): Multiplier of the random Fourier frequencies ("fourier" only).
            Defaults to 2.0.

    Raises:
        ValueError: If ``encoding_type`` is not "sincos" or "fourier".
    """

    def __init__(
        self,
        dim: int,
        encoding_type: Literal["sincos", "fourier"] = "sincos",
        scale: float = 2.0,
    ):
        super().__init__()
        self.dim = dim
        self.encoding_type = encoding_type
        if encoding_type == "fourier":
            self.register_buffer("freqs", torch.randn(dim // 2) * scale, persistent=True)
        elif encoding_type == "sincos":
            pass
        else:
            raise ValueError(f"encoding_type must be 'sincos' or 'fourier', but got {encoding_type}")

    def _create_sincos_emb(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Return a (seq_len, dim) table with sin on even channels and cos on odd channels."""
        position = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) * -(math.log(10000.0) / self.dim)
        )
        pos_emb = torch.zeros(seq_len, self.dim, device=device, dtype=dtype)
        pos_emb[:, 0::2] = torch.sin(position * div_term).to(dtype)
        # An odd `dim` has one fewer odd channel than even channels; trim the frequencies to fit.
        pos_emb[:, 1::2] = torch.cos(position * div_term[: self.dim // 2]).to(dtype)
        return pos_emb

    def _create_fourier_emb(self, timestamps: torch.Tensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Return (b, t, 2 * len(freqs)) features ``[cos, sin]`` of ``2*pi*timestamps*freqs``."""
        # Compute the phases in float32 regardless of the buffer dtype (a module cast to
        # bf16/fp16 also casts `freqs`), so einsum never mixes dtypes and phases keep precision.
        ts = timestamps.to(device=device, dtype=torch.float32)
        angular = 2 * np.pi * self.freqs.to(device=device, dtype=torch.float32)
        pos_emb = torch.einsum("b t, d -> b t d", ts, angular)
        pos_emb = torch.cat([pos_emb.cos(), pos_emb.sin()], dim=-1).to(dtype)
        return pos_emb

    def forward(
        self, x: torch.Tensor, freq: Optional[Union[float, torch.Tensor]] = None, dtype: torch.dtype = torch.float32
    ) -> torch.Tensor:
        """Return positional embeddings (scaled by 0.1) for the first two axes of ``x``.

        Args:
            x: Tensor whose first two dimensions are (batch, seq_len); only its shape and
                device are used.
            freq: Sampling frequency. A Python number or 0-d tensor is shared by the whole batch;
                otherwise a (batch,) tensor. Required for "fourier", ignored for "sincos".
            dtype: dtype of the returned embedding.

        Returns:
            Tensor of shape (batch, seq_len, ...) with ``dim`` channels for "sincos" and
            ``2 * (dim // 2)`` channels for "fourier".

        Raises:
            ValueError: If ``freq`` is missing for the "fourier" encoding.
        """
        b, t = x.shape[0], x.shape[1]
        device = x.device
        if self.encoding_type == "sincos":
            pos_emb = self._create_sincos_emb(t, device, dtype)
            return pos_emb.unsqueeze(0).expand(b, -1, -1) * 0.1
        if self.encoding_type == "fourier":
            if freq is None:
                raise ValueError(
                    "freq must be provided when encoding_type is 'fourier'. Please provide the sequence frequency."
                )
            # Accept ints, floats, 0-d and (b,) tensors of any dtype; work in float32.
            freq = torch.as_tensor(freq, dtype=torch.float32, device=device)
            if freq.dim() == 0:
                freq = freq.expand(b)
            positions = torch.arange(t, dtype=torch.float32, device=device)
            timestamps = torch.einsum("t, b -> b t", positions, 1 / freq)
            pos_emb = self._create_fourier_emb(timestamps, device, dtype)
            return pos_emb * 0.1
        raise ValueError(f"Unknown encoding_type: {self.encoding_type}")
class SinusoidalPositionalEmbedding(PositionalEmbedding):
    """``PositionalEmbedding`` fixed to the "sincos" encoding.

    Args:
        dim (int): Embedding size.
    """

    def __init__(self, dim: int):
        super().__init__(dim=dim, encoding_type="sincos")

    def forward(self, x: torch.Tensor, pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return sinusoidal embeddings for the (batch, seq_len) leading axes of ``x``.

        ``pos`` is accepted for interface compatibility but ignored.
        """
        embedding = super().forward(x, freq=None)
        return embedding
class FeedForward(nn.Module):
    """Two-layer MLP ``hidden_size -> intermediate_size -> hidden_size`` with a GELU in between.

    Args:
        hidden_size (int): Input and output feature size.
        intermediate_size (int): Width of the hidden layer.
        bias (bool, optional): Whether both linear layers carry a bias. Defaults to False.
    """

    def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
        super().__init__()
        # Creation order (up, down, act) is kept so random initialisation is unchanged.
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
        self.act_fn = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.up_proj(x)
        hidden = self.act_fn(hidden)
        return self.down_proj(hidden)
class LayerScale(nn.Module):
    """Learnable per-channel multiplier over the last dimension.

    Args:
        dim: Number of channels.
        init_val: Initial value of every scale entry. Defaults to 1e-2.
    """

    def __init__(self, dim, init_val=1e-2):
        super().__init__()
        initial = torch.full([dim], init_val)
        self.scale = nn.Parameter(initial)

    def forward(self, x):
        return self.scale * x
class PerceiverTransformerBlock(nn.Module):
    """Pre-norm Perceiver block: cross-attention to a context, optional self-attention, then an MLP.

    Each sub-layer is applied as ``x = x + scale(dropout(sublayer(norm(x))))``. The same
    ``attn_scale`` module scales both the cross-attention and the self-attention branch.

    Args:
        dim (int): Token size.
        num_heads (int): Attention heads for every attention sub-layer.
        mlp_ratio (int, optional): MLP hidden width as a multiple of ``dim``. Defaults to 4.
        dropout (float, optional): Dropout applied to every branch output. Defaults to 0.0.
        qk_norm (str, optional): Query/key normalization passed to ``Attention``. Defaults to "ln".
        layer_scale (bool, optional): Use ``LayerScale`` on the branches (identity otherwise).
            Defaults to True.
        zero_init_output (bool, optional): Zero-init the attention output projections and the MLP
            down projection. Defaults to False.
        add_self_attn (bool, optional): Add a self-attention sub-layer over the queries.
            Defaults to False.
        add_causal_mask (bool, optional): Make that self-attention causal. Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: int = 4,
        dropout: float = 0.0,
        qk_norm: str = "ln",
        layer_scale: bool = True,
        zero_init_output: bool = False,
        add_self_attn: bool = False,
        add_causal_mask: bool = False,
    ):
        super().__init__()
        self.add_self_attn = add_self_attn
        self.add_causal_mask = add_causal_mask
        attn_kwargs = dict(
            hidden_size=dim, num_heads=num_heads, qk_norm=qk_norm, bias=False, zero_init_output=zero_init_output
        )
        # Submodules are created in a fixed order so random initialisation stays reproducible.
        self.norm1 = nn.LayerNorm(dim, eps=1e-2)
        self.cross_attn = Attention(**attn_kwargs)
        if add_self_attn:
            self.norm_self_attn = nn.LayerNorm(dim, eps=1e-2)
            self.self_attn = Attention(**attn_kwargs)
        else:
            self.self_attn = None
        self.norm2 = nn.LayerNorm(dim, eps=1e-2)
        self.mlp = FeedForward(hidden_size=dim, intermediate_size=int(mlp_ratio * dim), bias=True)
        self.dropout = nn.Dropout(dropout)
        self.attn_scale = LayerScale(dim) if layer_scale else nn.Identity()
        self.mlp_scale = LayerScale(dim) if layer_scale else nn.Identity()
        if zero_init_output:
            nn.init.zeros_(self.mlp.down_proj.weight)
            if self.mlp.down_proj.bias is not None:
                nn.init.zeros_(self.mlp.down_proj.bias)

    def _residual(self, x: torch.Tensor, branch, scale: nn.Module) -> torch.Tensor:
        """Return ``x + scale(dropout(branch(x)))``."""
        return scale(self.dropout(branch(x))) + x

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor,
        context_mask: Optional[torch.Tensor] = None,
        rotary_pos_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Update query tokens ``x`` using ``context``.

        Args:
            x: Query tokens, (b, l, dim).
            context: Tokens to cross-attend to, (b, s, dim).
            context_mask: Optional boolean mask for the cross-attention (``True`` = attend).
            rotary_pos_emb: Optional ``(sin, cos)`` pair used by every attention sub-layer.

        Returns:
            Updated query tokens with the same shape as ``x``.
        """
        x = self._residual(
            x,
            lambda h: self.cross_attn(
                x=self.norm1(h), context=context, mask=context_mask, rotary_pos_emb=rotary_pos_emb, is_causal=False
            ),
            self.attn_scale,
        )
        if self.add_self_attn:
            x = self._residual(
                x,
                lambda h: self.self_attn(
                    x=self.norm_self_attn(h),
                    context=None,
                    mask=None,
                    rotary_pos_emb=rotary_pos_emb,
                    is_causal=self.add_causal_mask,
                ),
                self.attn_scale,
            )
        return self._residual(x, lambda h: self.mlp(self.norm2(h)), self.mlp_scale)
class EmbodimentEmbedding(nn.Module):
    """Learned per-embodiment token sequence, looked up by integer embodiment id.

    Each embodiment owns ``out_len`` tokens of size ``out_dim``, stored flattened in one row
    of an ``nn.Embedding``. Ids follow the insertion order of ``embodiment_config``.

    Args:
        embodiment_config (dict): Mapping from embodiment name to its settings. Only the keys
            (and their order) are used here. The dict is stored by reference and is mutated by
            ``expand_embodiment``.
        out_len (int): Number of tokens per embodiment.
        out_dim (int): Size of each token.
    """

    def __init__(self, embodiment_config: dict, out_len: int, out_dim: int) -> None:
        super().__init__()
        self.out_len, self.out_dim = out_len, out_dim
        self.embodiment_config = embodiment_config
        self.num_embodiments = len(self.embodiment_config)
        self.embedding = nn.Embedding(self.num_embodiments, out_dim * out_len)

    def expand_embodiment(self, embodiment_config: dict):
        """Append new embodiments while keeping the learned rows of the existing ones.

        The replacement embedding is created on the device and with the dtype of the current
        one; new rows use ``nn.Embedding``'s default initialisation.

        Args:
            embodiment_config (dict): New embodiments; their names must not exist yet.

        Returns:
            EmbodimentEmbedding: ``self``, for chaining.

        Raises:
            AssertionError: If a name is already registered (checked before anything changes).
        """
        for k in embodiment_config:
            assert k not in self.embodiment_config, f"embodiment {k!r} is already registered"
        if not embodiment_config:
            # Nothing to add: keep the current table untouched.
            return self
        old_weight = self.embedding.weight.detach().clone()
        num_old = old_weight.shape[0]
        self.embodiment_config.update(embodiment_config)
        self.num_embodiments = len(self.embodiment_config)
        self.embedding = nn.Embedding(
            self.num_embodiments,
            self.out_dim * self.out_len,
            device=old_weight.device,
            dtype=old_weight.dtype,
        )
        with torch.no_grad():
            self.embedding.weight[:num_old] = old_weight
        return self

    def keys(self) -> list[str]:
        """Embodiment names in id order."""
        return list(self.embodiment_config.keys())

    def ids_to_keys(self, ids: torch.Tensor) -> List[str]:
        """Map embodiment ids to names."""
        names = self.keys()
        return [names[i] for i in ids]

    def keys_to_ids(self, keys: List[str]) -> torch.Tensor:
        """Map embodiment names to a long tensor of ids (raises ValueError for unknown names)."""
        names = self.keys()
        # dtype pinned so an empty list still yields an index tensor usable by nn.Embedding.
        return torch.tensor([names.index(k) for k in keys], dtype=torch.long)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Look up ids ``x`` of shape (b,) and return tokens of shape (b, out_len, out_dim)."""
        return einops.rearrange(self.embedding(x), "b (l d) -> b l d", d=self.out_dim)
class PerceiverEncoder(nn.Module):
    """Perceiver-style action encoder: learned per-embodiment query tokens cross-attend to actions.

    Actions of every embodiment are zero-padded to ``max_action_dim`` and projected to
    ``config.encoder_dim``. Each embodiment owns ``config.n_tokens // config.n_quantizers`` query
    tokens, refined by ``config.encoder_n_layers`` Perceiver blocks and projected to ``config.z_dim``.

    Args:
        config (ActionCodecConfig): Model configuration. ``config.embodiment_config`` maps each
            embodiment name to a dict with at least ``"action_dim"``, ``"freq"`` and ``"duration"``.
            It is deep-copied, so later edits to ``config`` do not affect this module.
    """

    def __init__(self, config: ActionCodecConfig):
        super().__init__()
        self.config = config
        self.embodiment_config = deepcopy(config.embodiment_config)
        out_len = int(config.n_tokens // config.n_quantizers)
        dim = config.encoder_dim
        # Per-embodiment metadata indexed by embodiment id (insertion order of the config).
        # Non-persistent: always rebuilt from the config, never stored in checkpoints.
        action_dim, freq, duration = self._embodiment_tensors()
        self.register_buffer("_action_dim", action_dim, persistent=False)
        self.register_buffer("_freq", freq, persistent=False)
        self.register_buffer("_duration", duration, persistent=False)
        self.max_action_dim = max(v["action_dim"] for v in self.embodiment_config.values())
        self.input_proj = nn.Linear(self.max_action_dim, dim)
        self.cls_tokens = EmbodimentEmbedding(self.embodiment_config, out_len, dim)
        self.pos_emb_q = PositionalEmbedding(dim, encoding_type="sincos")
        self.pos_emb_kv = PositionalEmbedding(dim, encoding_type=config.encoder_pos_encoding_type)
        self.layers = nn.ModuleList(
            [
                PerceiverTransformerBlock(
                    dim=dim,
                    num_heads=config.encoder_n_heads,
                    add_self_attn=config.encoder_add_self_attn,
                    add_causal_mask=config.encoder_add_causal_mask,
                )
                for _ in range(config.encoder_n_layers)
            ]
        )
        self.output_proj = nn.Linear(dim, config.z_dim)
        self._init_weights()

    def _embodiment_tensors(self, device: Optional[torch.device] = None):
        """Return ``(action_dim, freq, duration)`` tensors in embodiment-id order."""
        values = list(self.embodiment_config.values())
        return (
            torch.tensor([v["action_dim"] for v in values], device=device),
            torch.tensor([v["freq"] for v in values], device=device),
            torch.tensor([v["duration"] for v in values], device=device),
        )

    def _init_weights(self):
        nn.init.trunc_normal_(self.input_proj.weight, std=0.02)
        if self.input_proj.bias is not None:
            nn.init.zeros_(self.input_proj.bias)
        nn.init.trunc_normal_(self.output_proj.weight, std=0.02)
        if self.output_proj.bias is not None:
            nn.init.zeros_(self.output_proj.bias)
        nn.init.trunc_normal_(self.cls_tokens.embedding.weight, std=0.02)

    def expand_embodiment(self, embodiment_config: dict):
        """Register new embodiments in place, keeping the behaviour for existing ones.

        Existing query tokens and input-projection weights are copied over. If a new embodiment
        has a larger ``action_dim``, ``input_proj`` gains extra input columns; since old
        embodiments' actions are zero-padded, those columns do not change their encodings.
        The rebuilt metadata buffers and ``input_proj`` stay on the device (and, for the
        projection, the dtype) of what they replace.

        Args:
            embodiment_config (dict): New embodiments, same schema as ``config.embodiment_config``.

        Returns:
            PerceiverEncoder: ``self``.
        """
        self.cls_tokens.expand_embodiment(embodiment_config)
        self.embodiment_config = self.cls_tokens.embodiment_config
        # Assigning to registered buffer names keeps them registered buffers.
        self._action_dim, self._freq, self._duration = self._embodiment_tensors(self._freq.device)
        max_action_dim = max(v["action_dim"] for v in self.embodiment_config.values())
        if max_action_dim > self.max_action_dim:
            old_proj = self.input_proj
            new_proj = nn.Linear(
                max_action_dim,
                self.config.encoder_dim,
                device=old_proj.weight.device,
                dtype=old_proj.weight.dtype,
            )
            with torch.no_grad():
                new_proj.weight[:, : self.max_action_dim] = old_proj.weight
                if old_proj.bias is not None:
                    new_proj.bias.copy_(old_proj.bias)
            self.input_proj = new_proj
            self.max_action_dim = max_action_dim
        return self

    def forward(
        self,
        x: torch.Tensor,
        embodiment_ids: torch.Tensor | int,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode action sequences into latent representations.

        Args:
            x (torch.Tensor): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
                Assumes that the action dimension is zero-padded to the max action dimension.
                `seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
            embodiment_ids (torch.Tensor | int): Embodiment IDs. Shape: (b,).
                If int, the same embodiment ID is repeated for all sequences in the batch.
                It specifies the embodiment to encode.
            padding_mask (Optional[torch.Tensor], optional): Padding mask, where `False` values indicate padding. Shape: (b, seq_len). Defaults to None.
                It is used to mask the padding tokens on `seq_len` dimension.
                NOTE(review): a row that is entirely `False` leaves its queries nothing to
                attend to, which may produce NaN in scaled_dot_product_attention.

        Returns:
            torch.Tensor: Encoded latent representations. Shape: (b, n_tokens_per_quantizer, z_dim).
        """
        b = x.shape[0]
        x = self.input_proj(x)
        if isinstance(embodiment_ids, int):
            embodiment_ids = torch.full((b,), embodiment_ids, dtype=torch.long, device=x.device)
        cls_tokens = self.cls_tokens(embodiment_ids)
        freqs = self._freq[embodiment_ids].to(x.device, x.dtype)
        pos_emb_q = self.pos_emb_q(cls_tokens)
        pos_emb_kv = self.pos_emb_kv(x, freqs)
        cls_tokens = cls_tokens + pos_emb_q
        x = x + pos_emb_kv
        if padding_mask is not None:
            # (b, seq_len) -> (b, n_query_tokens, seq_len): the same key mask for every query.
            padding_mask = padding_mask.unsqueeze(1).expand(-1, cls_tokens.shape[1], -1)
        for layer in self.layers:
            cls_tokens = layer(x=cls_tokens, context=x, context_mask=padding_mask)
        return self.output_proj(cls_tokens)
class PerceiverDecoder(nn.Module):
    """Perceiver-style action decoder: per-embodiment query tokens cross-attend to latent tokens.

    Latents are projected from ``config.z_dim`` to ``config.decoder_dim``. Each embodiment owns
    ``config.decoder_cls_size`` query tokens (repeated over the action horizon when that size is
    1), refined by ``config.decoder_n_layers`` Perceiver blocks and projected to ``max_action_dim``.

    Args:
        config (ActionCodecConfig): Model configuration. ``config.embodiment_config`` maps each
            embodiment name to a dict with at least ``"action_dim"``, ``"freq"`` and ``"duration"``.
            It is deep-copied, so later edits to ``config`` do not affect this module.
    """

    def __init__(self, config: ActionCodecConfig):
        super().__init__()
        self.config = config
        self.embodiment_config = deepcopy(config.embodiment_config)
        dim = config.decoder_dim
        # Per-embodiment metadata indexed by embodiment id (insertion order of the config).
        # Non-persistent: always rebuilt from the config, never stored in checkpoints.
        action_dim, freq, duration = self._embodiment_tensors()
        self.register_buffer("_action_dim", action_dim, persistent=False)
        self.register_buffer("_freq", freq, persistent=False)
        self.register_buffer("_duration", duration, persistent=False)
        self.max_action_dim = max(v["action_dim"] for v in self.embodiment_config.values())
        self.input_proj = nn.Linear(config.z_dim, dim)
        self.cls_tokens = EmbodimentEmbedding(self.embodiment_config, config.decoder_cls_size, dim)
        self.pos_emb_q = PositionalEmbedding(dim, encoding_type=config.decoder_pos_encoding_type)
        self.pos_emb_kv = PositionalEmbedding(dim, encoding_type="sincos")
        self.layers = nn.ModuleList(
            [
                PerceiverTransformerBlock(
                    dim=dim,
                    num_heads=config.decoder_n_heads,
                    add_self_attn=config.decoder_add_self_attn,
                    add_causal_mask=config.decoder_add_causal_mask,
                )
                for _ in range(config.decoder_n_layers)
            ]
        )
        self.output_proj = nn.Linear(dim, self.max_action_dim)
        self._init_weights()

    def _embodiment_tensors(self, device: Optional[torch.device] = None):
        """Return ``(action_dim, freq, duration)`` tensors in embodiment-id order."""
        values = list(self.embodiment_config.values())
        return (
            torch.tensor([v["action_dim"] for v in values], device=device),
            torch.tensor([v["freq"] for v in values], device=device),
            torch.tensor([v["duration"] for v in values], device=device),
        )

    def _init_weights(self):
        nn.init.trunc_normal_(self.input_proj.weight, std=0.02)
        if self.input_proj.bias is not None:
            nn.init.zeros_(self.input_proj.bias)
        nn.init.trunc_normal_(self.output_proj.weight, std=0.02)
        if self.output_proj.bias is not None:
            nn.init.zeros_(self.output_proj.bias)
        nn.init.trunc_normal_(self.cls_tokens.embedding.weight, std=0.02)

    def expand_embodiment(self, embodiment_config: dict):
        """Register new embodiments in place, keeping the behaviour for existing ones.

        Existing query tokens and output-projection rows are copied over. If a new embodiment has
        a larger ``action_dim``, ``output_proj`` gains extra output rows (freshly initialised);
        the first ``max_action_dim`` outputs stay unchanged. The rebuilt metadata buffers and
        ``output_proj`` stay on the device (and, for the projection, the dtype) of what they replace.

        Args:
            embodiment_config (dict): New embodiments, same schema as ``config.embodiment_config``.

        Returns:
            PerceiverDecoder: ``self``.
        """
        self.cls_tokens.expand_embodiment(embodiment_config)
        self.embodiment_config = self.cls_tokens.embodiment_config
        # Assigning to registered buffer names keeps them registered buffers.
        self._action_dim, self._freq, self._duration = self._embodiment_tensors(self._freq.device)
        max_action_dim = max(v["action_dim"] for v in self.embodiment_config.values())
        if max_action_dim > self.max_action_dim:
            old_proj = self.output_proj
            new_proj = nn.Linear(
                self.config.decoder_dim,
                max_action_dim,
                device=old_proj.weight.device,
                dtype=old_proj.weight.dtype,
            )
            with torch.no_grad():
                new_proj.weight[: self.max_action_dim, :] = old_proj.weight
                if old_proj.bias is not None:
                    new_proj.bias[: self.max_action_dim] = old_proj.bias
            self.output_proj = new_proj
            self.max_action_dim = max_action_dim
        return self

    def forward(
        self, x: torch.Tensor, embodiment_ids: torch.Tensor | int, durations: torch.Tensor | None = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decode latent representations into action sequences.

        Args:
            x (torch.Tensor): Latent representations to decode. Shape: (b, n_tokens_per_quantizer, z_dim).
            embodiment_ids (torch.Tensor | int): Embodiment IDs. Shape: (b,).
                If int, the same embodiment ID is repeated for all sequences in the batch.
                It specifies the embodiment to decode.
            durations (torch.Tensor | None, optional): Duration of each action sequence. Shape: (b,).
                A scalar or sequence of numbers is also accepted and converted to a tensor.
                If `None`, the duration is inferred from the default values in `embodiment_config`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - Decoded action sequences. Shape: (b, seq_len, max_action_dim).
                  Assumes that the action dimension is zero-padded to the max action dimension.
                  `seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
                - Boolean padding mask of shape (b, max_horizon), `True` on valid time steps.
        """
        b = x.shape[0]
        x = self.input_proj(x)
        if isinstance(embodiment_ids, int):
            embodiment_ids = torch.full((b,), embodiment_ids, dtype=torch.long, device=x.device)
        cls_tokens = self.cls_tokens(embodiment_ids)
        freqs = self._freq[embodiment_ids]
        if durations is None:
            durations = self._duration[embodiment_ids]
        else:
            durations = torch.as_tensor(durations, device=freqs.device)
        # NOTE(review): `.long()` truncates, so float rounding (e.g. 0.29 * 100) can lose one step.
        action_horizons = (durations * freqs).long().to(x.device)
        max_horizon = int(action_horizons.max().item())
        padding_mask = torch.arange(max_horizon, device=x.device).expand(b, -1) < action_horizons.unsqueeze(1)
        if self.config.decoder_cls_size == 1:
            # A single learned query is tiled over every output time step.
            cls_tokens = cls_tokens.repeat(1, max_horizon, 1)
        pos_emb_q = self.pos_emb_q(cls_tokens, freqs)
        pos_emb_kv = self.pos_emb_kv(x)
        cls_tokens = cls_tokens + pos_emb_q
        x = x + pos_emb_kv
        for layer in self.layers:
            cls_tokens = layer(x=cls_tokens, context=x)
        output = self.output_proj(cls_tokens)
        return output, padding_mask
if __name__ == "__main__":
    # Smoke tests for `expand_embodiment` on PerceiverEncoder / PerceiverDecoder.
    # NOTE: this module uses a relative import (`from .configuration_actioncodec import ...`),
    # so run it as part of its package (`python -m <package>.<module>`), not as a plain script.
    # Hard invariants use `assert`; soft checks are collected and turned into a non-zero
    # exit status at the end so automated runs notice regressions.
    failures: list[str] = []

    def report(ok: bool, pass_msg: str, fail_msg: str) -> None:
        """Print a PASS/FAIL line and remember the failure message when ``ok`` is False."""
        if ok:
            print(f"[PASS] {pass_msg}")
        else:
            print(f"[FAIL] {fail_msg}")
            failures.append(fail_msg)

    # ------------------------------------------
    # 1. Initialization
    # ------------------------------------------
    print("=== Test 1: Initialization ===")
    # Define initial config with two smaller robots
    initial_embodiment_config = {
        "robot_small_7d": {"action_dim": 7, "freq": 20, "duration": 1, "description": "Original Robot"},
        "robot_tiny_3d": {"action_dim": 3, "freq": 10, "duration": 2, "description": "Tiny Robot"},
    }
    config = ActionCodecConfig(embodiment_config=initial_embodiment_config)
    # Set seed for reproducibility
    torch.manual_seed(42)
    encoder = PerceiverEncoder(config)
    decoder = PerceiverDecoder(config)
    encoder.eval()
    decoder.eval()
    print("[OK] Models initialized successfully.")

    # ------------------------------------------
    # 2. Baseline Inference (Before Expansion)
    # ------------------------------------------
    print("\n=== Test 2: Baseline Inference (Before Expansion) ===")
    # Simulate Robot 1 (7-dim) data; max action dim currently is 7.
    batch_size = 1
    seq_len = 20  # 20Hz * 1s
    input_action_v0 = torch.randn(batch_size, seq_len, 7)  # (1, 20, 7)
    emb_id_v0 = torch.tensor([0], dtype=torch.long)  # ID 0 -> robot_small_7d
    with torch.no_grad():
        z_ref = encoder(input_action_v0, emb_id_v0)
        rec_action_ref, _ = decoder(z_ref, emb_id_v0)
    print(f"Reference Latent Shape: {z_ref.shape}")
    print(f"Reference Recon Shape: {rec_action_ref.shape}")

    # ------------------------------------------
    # 3. Model Expansion (Add New Embodiment)
    # ------------------------------------------
    print("\n=== Test 3: Model Expansion ===")
    # Add a larger robot: 10-dim, high frequency
    new_embodiment_config = {
        "robot_large_10d": {"action_dim": 10, "freq": 30, "duration": 1, "description": "New Large Robot"}
    }
    print(f"Expanding from Max Dim {encoder.max_action_dim} to 10...")
    encoder.expand_embodiment(new_embodiment_config)
    decoder.expand_embodiment(new_embodiment_config)
    # Verify buffer updates
    assert encoder._action_dim[-1] == 10
    assert encoder.max_action_dim == 10
    assert decoder.max_action_dim == 10
    print(f"[OK] Expansion successful. New Encoder Input Dim: {encoder.input_proj.weight.shape[1]}")
    print(f"[OK] New Decoder Output Dim: {decoder.output_proj.weight.shape[0]}")

    # ------------------------------------------
    # 4. Encoder Invariance Check
    # ------------------------------------------
    print("\n=== Test 4: Encoder Invariance Check ===")
    # Pad old data (7 dims) to new max dim (10 dims) with ZEROS.
    input_action_padded = torch.zeros(batch_size, seq_len, 10)
    input_action_padded[:, :, :7] = input_action_v0
    with torch.no_grad():
        z_new = encoder(input_action_padded, emb_id_v0)
    diff_z = (z_ref - z_new).abs().max().item()
    print(f"Latent Difference (Max Abs): {diff_z:.8f}")
    report(
        diff_z < 1e-6,
        "Encoder produces identical latents for old data.",
        "Encoder outputs changed after expansion!",
    )

    # ------------------------------------------
    # 5. Decoder Invariance Check
    # ------------------------------------------
    print("\n=== Test 5: Decoder Invariance Check ===")
    with torch.no_grad():
        # Feed old latent to expanded decoder; output shape should be (1, 20, 10)
        rec_action_new_full, _ = decoder(z_ref, emb_id_v0)
    print(f"Expanded Decoder Output Shape: {rec_action_new_full.shape}")
    # Slice first 7 dims, should match reference
    rec_action_new_sliced = rec_action_new_full[:, :, :7]
    diff_rec = (rec_action_ref - rec_action_new_sliced).abs().max().item()
    print(f"Reconstruction Difference (Max Abs on valid dims): {diff_rec:.8f}")
    report(
        diff_rec < 1e-6,
        "Decoder produces identical action values for valid dimensions.",
        "Decoder outputs changed!",
    )
    # Phantom dimensions (7-9) of the old embodiment come from freshly initialised rows.
    new_dims_mean = rec_action_new_full[:, :, 7:].abs().mean().item()
    print(f"Values in new phantom dimensions (should be random garbage): {new_dims_mean:.4f}")

    # ------------------------------------------
    # 6. New Embodiment Inference
    # ------------------------------------------
    print("\n=== Test 6: New Embodiment Inference ===")
    emb_id_new = torch.tensor([2], dtype=torch.long)  # ID 2 -> robot_large_10d
    seq_len_new = 30  # 30Hz * 1s
    input_action_new = torch.randn(1, seq_len_new, 10)
    with torch.no_grad():
        z_large = encoder(input_action_new, emb_id_new)
        rec_large, mask_large = decoder(z_large, emb_id_new)
    print(f"New Embodiment Output Shape: {rec_large.shape}")
    report(
        rec_large.shape == (1, 30, 10),
        "New embodiment handled correctly with full dimensions.",
        f"Expected (1, 30, 10), got {rec_large.shape}",
    )

    # ------------------------------------------
    # 7. Mixed Batch Processing (Masking)
    # ------------------------------------------
    print("\n=== Test 7: Mixed Batch Processing ===")
    # Batch size 2: [Robot 0 (20Hz, 7dim), Robot 2 (30Hz, 10dim)]
    mixed_emb_ids = torch.tensor([0, 2], dtype=torch.long)
    # Max seq len is 30. Max action dim is 10.
    batch_input = torch.zeros(2, 30, 10)
    batch_input[0, :20, :7] = torch.randn(20, 7)  # Batch 0: Length 20, Dim 7 valid
    batch_input[1, :30, :10] = torch.randn(30, 10)  # Batch 1: Length 30, Dim 10 valid
    # Encoder Mask: True = Valid
    enc_padding_mask = torch.zeros(2, 30, dtype=torch.bool)
    enc_padding_mask[0, :20] = True
    enc_padding_mask[1, :30] = True
    print("Running mixed batch...")
    with torch.no_grad():
        z_mixed = encoder(batch_input, mixed_emb_ids, padding_mask=enc_padding_mask)
        rec_mixed, dec_padding_mask = decoder(z_mixed, mixed_emb_ids)
    print(f"Mixed Reconstruction Shape: {rec_mixed.shape}")  # Should be (2, 30, 10)
    # Verify Decoder Generated Mask
    valid_len_0 = dec_padding_mask[0].sum().item()
    valid_len_1 = dec_padding_mask[1].sum().item()
    print(f"Decoder Mask Valid Lengths: Batch 0={valid_len_0}, Batch 1={valid_len_1}")
    report(
        valid_len_0 == 20 and valid_len_1 == 30,
        "Decoder correctly generated masks based on frequency and duration.",
        "Decoder masks are incorrect.",
    )

    print("\n*** All Tests Completed ***")
    if failures:
        raise SystemExit(f"{len(failures)} check(s) failed: " + "; ".join(failures))