diff --git "a/prismatic/models/action_heads.py" "b/prismatic/models/action_heads.py" new file mode 100644--- /dev/null +++ "b/prismatic/models/action_heads.py" @@ -0,0 +1,2716 @@ +"""Implementations of various action heads, which serve as alternatives to VLM sequential token prediction. + +Supported Normalization Methods (optimized for dual-arm robots): +- per_limb_layernorm: Independent LayerNorm for each limb (recommended for dual-arm) +- per_limb_rmsnorm: Independent RMSNorm for each limb (recommended for dual-arm) +- rmsnorm: Root Mean Square Layer Normalization (preserves relative relationships) +- layerscale: Learnable layer scaling (lightweight, preserves coordination) +- l2: L2 Normalization (preserves relative magnitudes) +- none/identity: No normalization (prevents coordination disruption) +- layernorm: Standard Layer Normalization (may cause dual-arm jitter) +- scalenorm: Simplified LayerNorm variant with only scaling +""" + +import math + +import numpy as np +import torch +import torch.nn as nn +from torchvision.transforms.functional import gaussian_blur +# from diffusers.schedulers.scheduling_ddim import DDIMScheduler +from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX , SHORT_NUM_ACTIONS_CHUNK, MID_NUM_ACTIONS_CHUNK +from prismatic.models.query_projection import Query2ActionAdapter +import torch.nn.functional as F +from einops import rearrange + + +class RMSNorm(nn.Module): + def __init__(self, d_model: int, eps: float = 1e-5): + super().__init__() + + self.eps = eps + self.weight = nn.Parameter(torch.ones(d_model)) + + def forward(self, x): + output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight + + return output + +class SinusoidalPositionalEncoding(nn.Module): + """ + Sine- and cosine-based positional encoding that produces embeddings of a batch of timesteps. 
+ + For example, at train time, the input might be a batch of 32 randomly sampled diffusion timesteps -> shape (32,) + Then the output would be a batch of 32 timestep embeddings -> shape (32, D) + + Adapted from: https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/positional_embedding.py + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim # dimensionality of the positional encoding + + def forward(self, x): + # x: (batch_size,) + device = x.device + assert self.dim % 2 == 0, f"# dimensions must be even but got {self.dim}" + half_dim = self.dim // 2 + exponent = torch.arange(half_dim, device=device) * -math.log(10000) / (half_dim - 1) # shape: (D/2,) + emb = torch.exp(exponent) # shape: (D/2,) + emb = x[:, None] * emb[None, :] # shape: (batch_size, 1) * (1, D/2) -> (batch_size, D/2) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) # shape: (batch_size, D) + return emb + + +class MLPResNetBlock(nn.Module): + """One MLP ResNet block with a residual connection.""" + def __init__(self, dim): + super().__init__() + self.dim = dim + self.ffn = nn.Sequential( # feedforward network, similar to the ones in Transformers + nn.LayerNorm(dim), + nn.Linear(dim, dim), + nn.ReLU(), + ) + + def forward(self, x): + # x: (batch_size, hidden_dim) + # We follow the module ordering of "Pre-Layer Normalization" feedforward networks in Transformers as + # described here: https://arxiv.org/pdf/2002.04745.pdf + identity = x + x = self.ffn(x) + x = x + identity + return x + + +class MLPResNet(nn.Module): + """MLP with residual connection blocks.""" + def __init__(self, num_blocks, input_dim, hidden_dim, output_dim): + super().__init__() + self.layer_norm1 = nn.LayerNorm(input_dim) + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.relu = nn.ReLU() + self.mlp_resnet_blocks = nn.ModuleList() + for _ in range(num_blocks): + self.mlp_resnet_blocks.append(MLPResNetBlock(dim=hidden_dim)) + self.layer_norm2 = nn.LayerNorm(hidden_dim) + self.fc2 = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + # x: (batch_size, input_dim) + x = self.layer_norm1(x) # shape: (batch_size, input_dim) + x = self.fc1(x) # shape: (batch_size, hidden_dim) + x = self.relu(x) # shape: (batch_size, hidden_dim) + for block in self.mlp_resnet_blocks: + x = block(x) # shape: (batch_size, hidden_dim) + x = self.layer_norm2(x) # shape: (batch_size, hidden_dim) + x = self.fc2(x) # shape: (batch_size, output_dim) + return x + + +class L1RegressionActionHead(nn.Module): + """Simple MLP-based action head that generates continuous actions via L1 regression.""" + def __init__( + self, + input_dim=4096, + hidden_dim=4096, + action_dim=7, + ): + super().__init__() + self.action_dim = action_dim + self.model = MLPResNet( + num_blocks=2, input_dim=input_dim*ACTION_DIM, hidden_dim=hidden_dim, output_dim=action_dim + ) + + def predict_action(self, actions_hidden_states, num_action_chunk = 8): + # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence + # - shape: (batch_size, chunk_len * action_dim, hidden_dim) + # ground_truth_actions: ground-truth actions + # - shape: (batch_size, chunk_len, action_dim) + batch_size = actions_hidden_states.shape[0] + device = actions_hidden_states.device + rearranged_actions_hidden_states = actions_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK, -1) + action = self.model(rearranged_actions_hidden_states) + return action + + + +class L1ActionProprioHead(nn.Module): + def __init__( + self, + input_dim=4096, 
+ hidden_dim=4096, + action_dim=7, + ): + super().__init__() + self.action_dim = action_dim + self.cross_attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=4, dropout=0.1,batch_first=True) + self.model = MLPResNet( + num_blocks=2, input_dim=input_dim*ACTION_DIM, hidden_dim=hidden_dim, output_dim=action_dim + ) + + def predict_action(self, actions_hidden_states, proprio_hidden_states ): + # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence + # - shape: (batch_size, chunk_len * action_dim, hidden_dim) + # ground_truth_actions: ground-truth actions + # - shape: (batch_size, chunk_len, action_dim) + batch_size = actions_hidden_states.shape[0] + device = actions_hidden_states.device + action_proprio_hidden_states = torch.cat([proprio_hidden_states,actions_hidden_states], dim=1) + fused_hidden_states = self.cross_attn(action_proprio_hidden_states,actions_hidden_states,actions_hidden_states)[0] + fused_hidden_states = fused_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK , -1) + action = self.model(fused_hidden_states) + return action + +class L1ProprioHead(nn.Module): + """Simple MLP-based action head that generates continuous actions via L1 regression.""" + def __init__( + self, + input_dim=4096, + hidden_dim=4096, + proprio_dim=8, + ): + super().__init__() + self.proprio_dim = proprio_dim + self.model = NewMLPResNet( + num_blocks=4, input_dim=input_dim, hidden_dim=hidden_dim, output_dim=proprio_dim * NUM_ACTIONS_CHUNK + ) + + def predict_proprio(self, proprio_hidden_states): + # proprios_hidden_states: last hidden states of Transformer corresponding to proprio tokens in sequence + # - shape: (batch_size, 1, hidden_dim) + # ground_truth_actions: ground-truth actions + # - shape: (batch_size, chunk_len, proprio_dim) + proprio_hidden_states = self.model(proprio_hidden_states) + proprio_hidden_states = proprio_hidden_states.reshape(proprio_hidden_states.shape[0], NUM_ACTIONS_CHUNK , -1) + return proprio_hidden_states + + +class NewMLPResNet(nn.Module): + """MLP with residual connection blocks.""" + def __init__(self, num_blocks, input_dim, hidden_dim, output_dim,drop_ratio=0.5): + super().__init__() + self.layer_norm1 = nn.LayerNorm(input_dim) + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.relu = nn.ReLU() + self.mlp_resnet_blocks = nn.ModuleList() + for _ in range(num_blocks): + self.mlp_resnet_blocks.append(MLPResNetBlock(dim=hidden_dim)) + self.layer_norm2 = nn.LayerNorm(hidden_dim) + self.dropout = nn.Dropout(drop_ratio) + self.fc2 = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + # x: (batch_size, input_dim) + x = self.layer_norm1(x) # shape: (batch_size, input_dim) + x = self.fc1(x) # shape: (batch_size, hidden_dim) + x = self.relu(x) # shape: (batch_size, hidden_dim) + for block in self.mlp_resnet_blocks: + x = block(x) # shape: (batch_size, hidden_dim) + x = self.layer_norm2(x) # shape: (batch_size, hidden_dim) + x = self.fc2(self.dropout(x)) # shape: (batch_size, output_dim) + return x + +# class TSActionHead(nn.Module): +# def __init__( +# self, +# input_dim=4096, +# hidden_dim=4096, +# action_dim=7, +# ): +# super().__init__() +# self.action_dim = action_dim +# self.heads = NewMLPResNet( +# num_blocks=2, input_dim=input_dim, hidden_dim=hidden_dim, output_dim=action_dim * NUM_ACTIONS_CHUNK +# ) +# def predict_action(self, actions_hidden_states): +# # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence +# # - shape: (batch_size, 1, hidden_dim) +# # 
ground_truth_actions: ground-truth actions +# # - shape: (batch_size, chunk_len, action_dim) +# actions = self.heads(actions_hidden_states) # (batch_size, 1, action_dim * NUM_ACTIONS_CHUNK) +# actions = actions.reshape(actions.size(0), NUM_ACTIONS_CHUNK, -1) +# return actions + + + +# class MultiScaleDecoder(nn.Module): +# def __init__(self, num_blocks, input_dim, hidden_dim, output_dims = [8, 16, 32, 64], drop_ratio=0.5): +# super().__init__() +# self.layer_norm1 = nn.LayerNorm(input_dim) +# self.fc1 = nn.Linear(input_dim, hidden_dim) +# self.relu = nn.ReLU() +# self.mlp_resnet_blocks = nn.ModuleList() +# for _ in range(num_blocks): +# self.mlp_resnet_blocks.append(MLPResNetBlock(dim=hidden_dim)) +# self.layer_norm2 = nn.LayerNorm(hidden_dim) +# self.dropout = nn.Dropout(drop_ratio) +# self.short_horizon = nn.Linear(hidden_dim, output_dims[0]) +# self.mid_horizon = nn.Linear(hidden_dim, output_dims[1]) +# self.long_horizon = nn.Linear(hidden_dim, output_dims[2]) +# self.base_horizon = nn.Linear(hidden_dim, output_dims[3]) + +# def forward(self, x , action_horizon_type = 'short' ): +# # x: (batch_size, input_dim) +# x = self.layer_norm1(x) # shape: (batch_size, input_dim) +# x = self.fc1(x) # shape: (batch_size, hidden_dim) +# x = self.relu(x) # shape: (batch_size, hidden_dim) +# for block in self.mlp_resnet_blocks: +# x = block(x) # shape: (batch_size, hidden_dim) +# x = self.layer_norm2(x) # shape: (batch_size, hidden_dim) +# if self.training: +# short_actions = self.short_horizon(self.dropout(x)) +# mid_actions = self.mid_horizon(self.dropout(x)) +# long_actions = self.long_horizon(self.dropout(x)) +# base_actions = self.base_horizon(self.dropout(x)) +# return [ short_actions, mid_actions, long_actions, base_actions ] +# else: +# if action_horizon_type == 'short': +# actions = self.short_horizon(self.dropout(x)) +# elif action_horizon_type == 'mid': +# actions = self.mid_horizon(self.dropout(x)) +# elif action_horizon_type == 'long': +# actions = self.long_horizon(self.dropout(x)) +# else: +# actions = self.base_horizon(self.dropout(x)) +# return actions + + +# class MultiScaleActionHead(nn.Module): +# def __init__( +# self, +# input_dim=4096, +# hidden_dim=4096, +# action_dim=7, +# ): +# super().__init__() +# self.action_dim = action_dim +# self.horizon_dims = [ SHORT_NUM_ACTIONS_CHUNK, MID_NUM_ACTIONS_CHUNK, LONG_NUM_ACTIONS_CHUNK, NUM_ACTIONS_CHUNK ] +# self.heads = MultiScaleDecoder( +# num_blocks=2, input_dim=input_dim, hidden_dim=hidden_dim, +# output_dims= [ action_dim * self.horizon_dims[0] , action_dim * self.horizon_dims[1], action_dim * self.horizon_dims[2], action_dim * self.horizon_dims[3] ] +# ) +# def predict_action(self, actions_hidden_states , action_horizon_type = None): +# # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence +# # - shape: (batch_size, 1, hidden_dim) +# # ground_truth_actions: ground-truth actions +# # - shape: (batch_size, chunk_len, action_dim) +# actions = self.heads(actions_hidden_states,action_horizon_type) # (batch_size, 1, action_dim * NUM_ACTIONS_CHUNK) +# if self.training: +# for i,dim in enumerate(self.horizon_dims): +# actions[i] = actions[i].reshape(actions[i].size(0), dim, -1) # actions: list +# else: +# actions = actions.reshape(actions.size(0), NUM_ACTIONS_CHUNK, -1) # actions: tensor +# return actions + +# class RoboFFN(nn.Module): +# def __init__(self, dim): +# super().__init__() +# self.dim = dim +# self.norm = nn.LayerNorm(dim) +# self.ffn = nn.Sequential( # feedforward network, 
similar to the ones in Transformers +# nn.Linear(dim, dim), +# nn.ReLU(), +# nn.Linear(dim, dim) +# ) + +# def forward(self, x): +# # x: (batch_size, hidden_dim) +# # We follow the module ordering of "Pre-Layer Normalization" feedforward networks in Transformers as +# # described here: https://arxiv.org/pdf/2002.04745.pdf +# identity = x +# x = self.norm(x) +# x = self.ffn(x) +# x = x + identity +# return x + +# class GatingMLP(nn.Module): +# def __init__(self, input_dim, hidden_dim, output_dims): +# super().__init__() +# self.norm = nn.LayerNorm(input_dim) +# self.gating = nn.Sequential( +# nn.Linear(input_dim, hidden_dim), +# nn.SiLU(), +# ) +# self.linear = nn.Linear(hidden_dim, hidden_dim) +# self.projection = nn.Linear(hidden_dim, output_dims) +# def forward(self, x): +# identity = x +# x = self.norm(x) +# x = self.gating(x) * self.linear(x) +# x = self.projection(x) +# return x + identity + +# class RobotDecoder(nn.Module): +# def __init__(self, num_blocks, input_dim, hidden_dim, output_dims, drop_ratio=0.5): +# super().__init__() +# self.gating_blocks = nn.Sequential( +# *[GatingMLP(input_dim=input_dim,hidden_dim=hidden_dim,output_dims=hidden_dim) for i in range(num_blocks)], +# ) +# self.norm = nn.LayerNorm(hidden_dim) +# self.dropout = nn.Dropout(drop_ratio) +# self.action_projection = nn.Linear(hidden_dim, output_dims) +# def forward(self, x ): +# x = self.gating_blocks(x) +# x = self.norm(x) +# return self.action_projection(self.dropout(x)) + +# class MultiScaleActionHead(nn.Module): +# def __init__( +# self, +# input_dim=4096, +# hidden_dim=4096, +# action_dim=7, +# decoder_num_blocks=2, +# ): +# super().__init__() +# self.action_dim = action_dim +# self.horizon_dims = [ SHORT_NUM_ACTIONS_CHUNK, MID_NUM_ACTIONS_CHUNK, NUM_ACTIONS_CHUNK ] +# self.multscaleheads = nn.ModuleList( +# [ +# RobotDecoder(num_blocks = decoder_num_blocks, input_dim=input_dim, hidden_dim=hidden_dim, output_dims=self.horizon_dims[i] *action_dim ) for i in range(len(self.horizon_dims)) +# ] +# ) +# def predict_action(self, actions_hidden_states , action_horizon_type = 0): +# # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence +# # - shape: (batch_size, 1, hidden_dim) +# # ground_truth_actions: ground-truth actions +# # - shape: (batch_size, chunk_len, action_dim) +# if self.training: +# actions = [] # actions: list +# for i,dim in enumerate(self.horizon_dims): +# action = self.multscaleheads[i](actions_hidden_states) +# action = action.reshape(action.size(0), dim, -1) +# actions.append(action) +# else: +# action = self.multscaleheads[action_horizon_type](actions_hidden_states) +# actions = actions.reshape(actions.size(0), self.horizon_dims[action_horizon_type], -1) # actions: tensor +# return actions + + +class L2Norm(nn.Module): + def __init__(self, dim=-1): + super().__init__() + self.dim = dim + def forward(self, x): + return F.normalize(x, p=2, dim=self.dim) + + +class LayerScale(nn.Module): + """可学习的层缩放模块,用于稳定深层网络训练""" + def __init__(self, dim: int, init_value: float = 1e-4): + super().__init__() + self.gamma = nn.Parameter(init_value * torch.ones(dim)) + + def forward(self, x): + return x * self.gamma + + +class ScaleNorm(nn.Module): + """ScaleNorm: 简化的LayerNorm变体,只包含缩放,适合双臂机器人""" + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + self.eps = eps + + def forward(self, x): + norm = x.norm(dim=-1, keepdim=True) + return self.scale * x / (norm + self.eps) + +class RoboFFN(nn.Module): + 
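+    # Usage sketch for this block (the hyperparameter values are illustrative placeholders,
+    # not defaults taken from this repo):
+    #   ffn = RoboFFN(hidden_dim=1024, ratio=2.0, ffn_type="swiglu",
+    #                 multi_query_norm_type="per_limb_rmsnorm", num_query=2)
+    #   x = torch.randn(4, 2, 1024)   # (batch, num_query, hidden_dim) for a dual-arm setup
+    #   y = ffn(x)                    # residual output, same shape as x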
def __init__( + self, + hidden_dim: int, + ratio: float = 1.0, + ffn_type: str = "relu", + dropout: float = 0.0, + # 新增多肢体相关参数 + multi_query_norm_type: str = 'layernorm', # 'none', 'l2', 'layernorm', 'per_limb_layernorm' + num_query: int = 1, # 肢体数量,双臂机器人为2 + ): + """ + 通用 FFN 模块,支持多种非线性 / gating 机制以提升动作空间表达能力。 + + 参数说明: + hidden_dim (int): 输入 / 输出维度。 + ratio (float): 中间层放大倍数,默认 1。 + ffn_type (str): {"relu", "gelu", "gated", "swiglu"} 之一。 + dropout (float): 激活后 dropout 概率。 + multi_query_norm_type (str): 归一化类型,支持多肢体独立归一化。 + num_query (int): 肢体数量。 + """ + super().__init__() + self.dim = hidden_dim + self.ffn_type = ffn_type + self.multi_query_norm_type = multi_query_norm_type + self.num_query = num_query + + inner_dim = int(hidden_dim * ratio) + + # 根据multi_query_norm_type选择归一化方法(针对双臂机器人优化) + if multi_query_norm_type == 'per_limb_layernorm': + # 推荐:为每个肢体创建独立的LayerNorm + self.norm = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_query)]) + elif multi_query_norm_type == 'per_limb_rmsnorm': + # 推荐:为每个肢体创建独立的RMSNorm + self.norm = nn.ModuleList([RMSNorm(hidden_dim) for _ in range(num_query)]) + elif multi_query_norm_type == 'rmsnorm': + # 推荐:RMSNorm保持相对关系 + self.norm = RMSNorm(hidden_dim) + elif multi_query_norm_type == 'l2': + self.norm = L2Norm() + elif multi_query_norm_type == 'scalenorm': + self.norm = ScaleNorm(hidden_dim) + elif multi_query_norm_type == 'none': + self.norm = nn.Identity() + elif multi_query_norm_type == 'layernorm': + # 注意:可能导致双臂抖动 + self.norm = nn.LayerNorm(hidden_dim) + else: + # 保持原有逻辑作为后备 + self.norm = nn.LayerNorm(hidden_dim) + + self.drop = nn.Identity() if dropout == 0 else nn.Dropout(dropout) + + if ffn_type in ["relu", "gelu", "silu"]: + if ffn_type == "gelu": + act_layer = nn.GELU + elif ffn_type == "silu": + act_layer = nn.SiLU + else: + act_layer = nn.ReLU + self.ffn = nn.Sequential( + nn.Linear(hidden_dim, inner_dim), + act_layer(), + self.drop, + nn.Linear(inner_dim, hidden_dim), + ) + elif ffn_type == 'norm_gelu_linear': + self.ffn = nn.Sequential( + nn.GELU(), + self.drop, + nn.Linear(inner_dim, hidden_dim), + ) + elif ffn_type == "gated": + # gate + up 合并在一张矩阵,参数量等同常见实现 + self.proj_in = nn.Linear(hidden_dim, inner_dim * 2) + self.act = nn.GELU() + self.proj_out = nn.Linear(inner_dim, hidden_dim) + elif ffn_type == "swiglu": + # 与 Llama / DeepSeek 风格一致的 SwiGLU + self.proj_in = nn.Linear(hidden_dim, inner_dim * 2, bias=False) + self.act = nn.SiLU() + self.proj_out = nn.Linear(inner_dim, hidden_dim, bias=False) + else: + raise ValueError(f"Unsupported ffn_type: {ffn_type}") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + + # 根据归一化类型处理 + if self.multi_query_norm_type in ['per_limb_layernorm', 'per_limb_rmsnorm'] and self.num_query > 1: + # 对每个肢体分别进行归一化(LayerNorm或RMSNorm) + # 假设输入shape为 (batch_size, num_query, hidden_dim) 或 (batch_size, num_query * seq_len, hidden_dim) + batch_size = x.size(0) + if x.size(1) == self.num_query: + # Case 1: (batch_size, num_query, hidden_dim) + x_normalized = [] + for limb_idx in range(self.num_query): + limb_features = x[:, limb_idx:limb_idx+1, :] # (batch_size, 1, hidden_dim) + normalized_limb = self.norm[limb_idx](limb_features) + x_normalized.append(normalized_limb) + x = torch.cat(x_normalized, dim=1) # (batch_size, num_query, hidden_dim) + else: + # Case 2: (batch_size, num_query * seq_len, hidden_dim) + # 假设seq_len相同,重新reshape + total_tokens = x.size(1) + seq_len = total_tokens // self.num_query + x_reshaped = x.view(batch_size, self.num_query, seq_len, -1) # (batch_size, num_query, 
seq_len, hidden_dim) + + x_normalized = [] + for limb_idx in range(self.num_query): + limb_features = x_reshaped[:, limb_idx, :, :] # (batch_size, seq_len, hidden_dim) + normalized_limb = self.norm[limb_idx](limb_features) + x_normalized.append(normalized_limb) + x_normalized = torch.stack(x_normalized, dim=1) # (batch_size, num_query, seq_len, hidden_dim) + x = x_normalized.view(batch_size, total_tokens, -1) # 恢复原始shape + else: + # 使用标准归一化 + x = self.norm(x) + + if self.ffn_type in ["relu", "gelu", "silu", "norm_gelu_linear"]: + x = self.ffn(x) + elif self.ffn_type in ["gated", "swiglu"]: + gate_up = self.proj_in(x) # (B, *, 2H) + gate, up = gate_up.chunk(2, dim=-1) + if self.ffn_type == "gated": + inter = torch.sigmoid(gate) * up # Gated-MLP + else: # swiglu + inter = self.act(gate) * up # SwiGLU + x = self.proj_out(self.drop(inter)) + else: + raise RuntimeError() + + return x + identity + +class PostFFN(nn.Module): + def __init__(self, hidden_dim, drop_ratio = 0.1): + super().__init__() + self.dim = hidden_dim + self.norm = nn.LayerNorm(hidden_dim) + self.drop_out = nn.Dropout(drop_ratio) + self.ffn = nn.Sequential( # feedforward network, similar to the ones in Transformers + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim) + ) + + def forward(self, x): + identity = x + x = self.ffn(x) + x = self.drop_out(x) + x = self.norm(x + identity) + return x + + +class GatingMLP(nn.Module): + def __init__(self, hidden_dim, drop_ratio = 0.1): + super().__init__() + self.norm = nn.LayerNorm(hidden_dim) + self.gating = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.SiLU(), + ) + # self.drop_out = nn.Dropout(drop_ratio) + self.linear = nn.Linear(hidden_dim, hidden_dim) + self.projection = nn.Linear(hidden_dim, hidden_dim) + def forward(self, x): + identity = x + x = self.norm(x) + x = self.gating(x) * self.linear(x) + x = self.projection(x) + x = x + identity + return x + +class Expert(nn.Module): + """ + DeepSeek V3风格的专家网络,使用GELU激活函数的标准FFN + """ + def __init__(self, hidden_dim: int, intermediate_dim: int = None, dropout: float = 0.0, expansion_ratio: float = 4.0): + super().__init__() + if intermediate_dim is None: + intermediate_dim = int(hidden_dim * expansion_ratio) # 可配置的扩展倍数 + + # 标准FFN架构:linear -> gelu -> linear + self.linear1 = nn.Linear(hidden_dim, intermediate_dim, bias=True) + # self.linear2 = nn.Linear(intermediate_dim, hidden_dim, bias=True) + self.activation = nn.GELU() + # 当dropout为0时使用恒等映射,避免不必要的计算开销 + self.dropout = nn.Identity() if dropout == 0.0 else nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear1(x) + x = self.activation(x) + x = self.dropout(x) + # x = self.linear2(x) + return x + + +class DeepSeekV3AdaptiveBiasRouter(nn.Module): + """DeepSeek V3的自适应偏置路由器,实现Loss-Free Balancing策略""" + def __init__( + self, + hidden_dim: int, + num_experts: int, + top_k: int = 2, + bias_update_speed: float = 0.01, + enable_bias_correction: bool = True + ): + super().__init__() + self.hidden_dim = hidden_dim + self.num_experts = num_experts + self.top_k = top_k + self.bias_update_speed = bias_update_speed + self.enable_bias_correction = enable_bias_correction + + # 路由器权重 - 使用论文中的初始化方法 + self.router = nn.Linear(hidden_dim, num_experts, bias=False) + # 使用较小的初始化标准差,有助于训练稳定性 + nn.init.normal_(self.router.weight, mean=0, std=0.02) + + # 自适应偏置 (不参与梯度计算,符合Loss-Free Balancing原理) + if enable_bias_correction: + self.register_buffer("adaptive_bias", torch.zeros(num_experts)) + + # Loss-Free Balancing的核心:维护每个专家的频率统计 + # 
这里使用EMA来追踪"recent load",符合论文描述 + self.register_buffer("expert_freq", torch.zeros(num_experts)) # f_i in paper + self.register_buffer("step_count", torch.tensor(0, dtype=torch.long)) + + def forward(self, x: torch.Tensor) -> tuple: + # x: (batch_size, seq_len, hidden_dim) + batch_size, seq_len, _ = x.shape + x_flat = x.reshape(-1, self.hidden_dim) # (batch_size * seq_len, hidden_dim) + + # 计算原始路由得分 + router_logits = self.router(x_flat) # (batch_size * seq_len, num_experts) + + # 应用自适应偏置校正 (Loss-Free Balancing的核心) + if self.enable_bias_correction and self.training: + router_logits = router_logits + self.adaptive_bias.unsqueeze(0) + + # 论文公式(15): s_{i,t} = Sigmoid(u_t^T e_i) + sigmoid_scores = torch.sigmoid(router_logits) # (batch_size * seq_len, num_experts) + + # 论文公式(14): g'_{i,t} - Top-K选择,其他设为0 + top_k_values, top_k_indices = torch.topk(sigmoid_scores, self.top_k, dim=-1) + + # 直接对 Top-K 值进行归一化,避免构造完整稀疏矩阵 (节约显存与时间) + normalized_weights = top_k_values / (top_k_values.sum(dim=-1, keepdim=True) + 1e-8) # (batch_size * seq_len, top_k) + + # Loss-Free Balancing的负载统计更新 + if self.training: + with torch.no_grad(): + self._update_expert_frequency(top_k_indices) + self._update_adaptive_bias() + + # 重新整形回原始批次维度 + top_k_weights = normalized_weights.reshape(batch_size, seq_len, self.top_k) + top_k_expert_indices = top_k_indices.reshape(batch_size, seq_len, self.top_k) + + return top_k_weights, top_k_expert_indices + + def _update_expert_frequency(self, expert_indices: torch.Tensor): + """更新专家使用频率统计 - 实现论文中的f_i计算""" + num_tokens = expert_indices.size(0) + self.step_count += num_tokens + + # 计算当前批次中每个专家的使用次数 + expert_counts = torch.zeros_like(self.expert_freq) + for i in range(self.top_k): + indices = expert_indices[:, i] + # 确保数据类型一致,使用expert_counts的dtype而不是强制使用float + expert_counts.scatter_add_(0, indices, torch.ones_like(indices, dtype=expert_counts.dtype)) + + # 计算当前批次的专家频率 f_i = (选择次数) / (总token数 * K/N) + # 这里K/N是平均每个token选择的专家比例 + current_freq = expert_counts / (num_tokens * self.top_k / self.num_experts) + + # 使用EMA更新频率统计,体现"recent load"的概念 + alpha = min(0.1, 1.0 / max(1, self.step_count.float() / 1000)) # 自适应学习率 + self.expert_freq = (1 - alpha) * self.expert_freq + alpha * current_freq + + def _update_adaptive_bias(self): + """根据Loss-Free Balancing算法更新自适应偏置""" + if not self.enable_bias_correction: + return + + # 论文公式:b_i <- b_i - u * sign(f_i - f_avg) + # 其中f_avg = 1(理想情况下每个专家的期望频率) + f_avg = 1.0 + # 按论文中 "b_i <- b_i - u * sign(f_i - f_avg)" 更新自适应偏置 + bias_delta = self.bias_update_speed * (self.expert_freq - f_avg) + self.adaptive_bias = self.adaptive_bias - bias_delta.clamp(-0.5, 0.5) # 防爆 + + # 限制偏置范围以防止数值不稳定 + self.adaptive_bias.clamp_(-10.0, 10.0) + + def get_load_balancing_loss(self): + """计算可选的负载均衡损失(主要用于监控)""" + if not self.training: + return torch.tensor(0.0, device=self.expert_freq.device) + + # 计算专家使用频率的方差作为不平衡指标 + freq_var = self.expert_freq.var() + return freq_var + + def get_routing_stats(self): + """获取路由统计信息用于监控""" + return { + 'expert_frequencies': self.expert_freq.float().cpu().numpy().tolist(), + 'adaptive_bias': self.adaptive_bias.float().cpu().numpy().tolist(), + 'frequency_std': float(self.expert_freq.float().std()), + 'bias_std': float(self.adaptive_bias.float().std()), + 'step_count': int(self.step_count) + } + + +class MoELayer(nn.Module): + """ + DeepSeek V3风格的MoE层,实现共享专家+路由专家架构 + + 论文公式:h_t = u_t + ∑(FFN_i^(s)(u_t)) + ∑(g_{i,t} * FFN_i^(r)(u_t)) + 其中s表示shared experts,r表示routed experts + """ + def __init__( + self, + hidden_dim: int, + num_experts: int = 6, + top_k: int 
= 2, + expert_capacity_factor: float = 1.0, + dropout: float = 0.0, + bias_update_speed: float = 0.1, + enable_shared_expert: bool = True, # 默认启用共享专家 + num_shared_experts: int = 1, + expansion_ratio: float = 2.0 # 可配置的专家网络扩展倍数 + ): + super().__init__() + self.hidden_dim = hidden_dim + self.num_experts = num_experts + self.top_k = top_k + self.expert_capacity_factor = expert_capacity_factor + self.enable_shared_expert = enable_shared_expert + self.num_shared_experts = num_shared_experts + self.expansion_ratio = expansion_ratio + + # 专家网络的中间维度,使用可配置的扩展倍数 + intermediate_dim = int(hidden_dim * expansion_ratio) + + # 路由专家网络 + self.experts = nn.ModuleList([ + Expert(hidden_dim, intermediate_dim, dropout) + for _ in range(num_experts) + ]) + + # 共享专家(DeepSeekMoE的关键组件) + if enable_shared_expert: + self.shared_experts = nn.ModuleList([ + Expert(hidden_dim, intermediate_dim, dropout) + for _ in range(num_shared_experts) + ]) + else: + self.shared_experts = None + + # DeepSeek V3风格的自适应偏置路由器 + self.router = DeepSeekV3AdaptiveBiasRouter( + hidden_dim=hidden_dim, + num_experts=num_experts, + top_k=top_k, + bias_update_speed=bias_update_speed + ) + + # 预归一化(Pre-LayerNorm架构) + self.norm = nn.LayerNorm(hidden_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + 实现DeepSeekMoE的前向传播 + + Args: + x: (batch_size, seq_len, hidden_dim) + Returns: + output: (batch_size, seq_len, hidden_dim) + """ + batch_size, seq_len, hidden_dim = x.shape + identity = x + + # 预归一化 + x_norm = self.norm(x) + + # 1. 共享专家处理 - 所有token都经过 + shared_output = torch.zeros_like(x_norm) + if self.shared_experts is not None: + for shared_expert in self.shared_experts: + shared_output += shared_expert(x_norm) + + # 2. 路由专家处理 - 基于路由器选择 + expert_weights, expert_indices = self.router(x_norm) # (B, S, top_k), (B, S, top_k) + + # 为了提高效率,重塑输入进行批量处理 + x_flat = x_norm.reshape(-1, hidden_dim) # (B*S, H) + expert_weights_flat = expert_weights.reshape(-1, self.top_k) # (B*S, top_k) + expert_indices_flat = expert_indices.reshape(-1, self.top_k) # (B*S, top_k) + + # 初始化路由输出 + routed_output_flat = torch.zeros_like(x_flat) + + # 高效的专家处理:按专家分组而非按token分组 + for expert_idx in range(self.num_experts): + # 收集所有使用当前专家的位置和权重 + expert_mask = (expert_indices_flat == expert_idx) # (B*S, top_k) + + if expert_mask.any(): + # 获取使用当前专家的token位置和对应的权重位置 + token_indices, weight_pos = expert_mask.nonzero(as_tuple=True) + + if len(token_indices) > 0: + # 获取对应的输入和权重 + expert_input = x_flat[token_indices] # (num_selected_tokens, H) + expert_weights_selected = expert_weights_flat[token_indices, weight_pos].unsqueeze(-1) # (num_selected_tokens, 1) + + # 通过当前专家网络处理 + expert_output = self.experts[expert_idx](expert_input) # (num_selected_tokens, H) + + # 应用权重并累加到对应位置 + weighted_output = expert_weights_selected * expert_output + routed_output_flat.index_add_(0, token_indices, weighted_output) + + # 重塑回原始形状 + routed_output = routed_output_flat.reshape(batch_size, seq_len, hidden_dim) + + # 3. 
按照DeepSeekMoE公式合并输出 + # h_t = u_t + ∑(FFN_i^(s)(u_t)) + ∑(g_{i,t} * FFN_i^(r)(u_t)) + final_output = identity + shared_output + routed_output + + return final_output + + def get_load_balancing_loss(self): + """获取负载均衡损失""" + return self.router.get_load_balancing_loss() + + def get_routing_stats(self): + """获取详细的路由统计信息""" + return self.router.get_routing_stats() + + +class MoERouter(nn.Module): + """ + 简化版MoE路由器,保持向后兼容 + """ + def __init__(self, hidden_dim: int, num_experts: int, top_k: int = 2): + super().__init__() + self.router = DeepSeekV3AdaptiveBiasRouter(hidden_dim, num_experts, top_k) + + def forward(self, x: torch.Tensor) -> tuple: + return self.router(x) + + +class RobotDecoder(nn.Module): + def __init__(self, num_blocks, + input_dim, + hidden_dim, + output_dims, + mlp_type = 'ffn', + ffn_type = 'relu', + proj_type= 'linear_relu', + drop_ratio=0.0, + without_action_projector=False, + action_norm="layernorm", + # MoE相关参数 + num_experts=6, + top_k=2, + expert_capacity_factor=1.0, + expansion_ratio=2.0, + num_shared_experts = 1, + # use_contrastive_loss + use_contrastive_loss=False, + # 新增多肢体相关参数 + multi_query_norm_type='layernorm', # 'none', 'l2', 'layernorm', 'per_limb_layernorm' + num_query=1, # 肢体数量,双臂机器人为2 + ): # 添加扩展倍数参数 + super().__init__() + + self.use_contrastive_loss = use_contrastive_loss + self.multi_query_norm_type = multi_query_norm_type + self.num_query = num_query + + if without_action_projector: + self.hidden_projection = nn.Identity() + else: + self.hidden_projection = Query2ActionAdapter( + input_dim=input_dim, + hidden_dim=hidden_dim, + proj_type=proj_type, + multi_query_norm_type=multi_query_norm_type, + num_query=num_query, + ) + + if num_blocks == 0 : + self.mlps = nn.Identity() + else: + if mlp_type == 'ffn': + self.mlps = nn.Sequential( + *[RoboFFN(hidden_dim=hidden_dim, ffn_type = ffn_type, ratio = expansion_ratio, + multi_query_norm_type=multi_query_norm_type, num_query=num_query) for i in range(num_blocks)], + ) + elif mlp_type == 'simhead': + self.mlps = nn.Sequential( + nn.LayerNorm(hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, int(hidden_dim * expansion_ratio)), + ) + elif mlp_type == 'postffn': + self.mlps = nn.Sequential( + nn.LayerNorm(hidden_dim), + *[PostFFN(hidden_dim=hidden_dim) for i in range(num_blocks)], + ) + elif mlp_type == 'moe': + self.mlps = nn.Sequential( + *[MoELayer( + hidden_dim=hidden_dim, + num_experts=num_experts, + top_k=top_k, + expert_capacity_factor=expert_capacity_factor, + expansion_ratio=expansion_ratio, # 传递扩展倍数参数 + num_shared_experts = num_shared_experts + ) for i in range(num_blocks)], + ) + else: + self.mlps = nn.Sequential( + *[GatingMLP(hidden_dim=hidden_dim) for i in range(num_blocks)], + ) + self.action_norm = action_norm + # 根据action_norm选择归一化方法(针对双臂机器人优化) + if action_norm == 'per_limb_layernorm': + # 推荐:为每个肢体创建独立的LayerNorm,避免肢体间干扰 + self.norm = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_query)]) + elif action_norm == 'per_limb_rmsnorm': + # 推荐:为每个肢体创建独立的RMSNorm,避免肢体间干扰 + self.norm = nn.ModuleList([RMSNorm(hidden_dim) for _ in range(num_query)]) + elif action_norm == 'rmsnorm': + # 推荐:RMSNorm保持相对关系,减少抖动 + self.norm = RMSNorm(hidden_dim) + elif action_norm == 'layerscale': + # 推荐:轻量级可学习缩放,保持协调性 + self.norm = LayerScale(hidden_dim, init_value=1e-4) + elif action_norm == 'l2': + # 推荐:L2归一化保持相对幅度关系 + self.norm = L2Norm() + elif action_norm == 'none' or action_norm == 'identity': + # 推荐:完全不归一化,避免协调性破坏 + self.norm = nn.Identity() + elif action_norm == 'layernorm': + # 
注意:可能导致双臂抖动,建议使用per_limb_layernorm替代 + self.norm = nn.LayerNorm(hidden_dim) + elif action_norm == 'insnorm': + self.norm = nn.InstanceNorm1d(hidden_dim) + elif action_norm == 'scalenorm': + # 轻量级选项:只包含缩放的简化LayerNorm + self.norm = ScaleNorm(hidden_dim) + else: + raise NotImplementedError(f"Unknown norm_type: {action_norm}. For dual-arm robots, consider: per_limb_layernorm, per_limb_rmsnorm, rmsnorm, layerscale, l2, or none") + + self.dropout = nn.Dropout(drop_ratio) if drop_ratio != 0 else nn.Identity() + self.action_projection = nn.Linear(hidden_dim, output_dims) if mlp_type != 'simhead' else nn.Linear(int(hidden_dim * expansion_ratio), output_dims) + + def forward(self, x): + x = self.hidden_projection(x) + x = self.mlps(x) + + if self.action_norm in ["per_limb_layernorm", "per_limb_rmsnorm"]: + # 对每个肢体分别进行LayerNorm或RMSNorm + # 假设输入shape为 (batch_size, num_query, hidden_dim) 或 (batch_size, num_query * seq_len, hidden_dim) + batch_size = x.size(0) + if x.size(1) == self.num_query: + # Case 1: (batch_size, num_query, hidden_dim) + x_normalized = [] + for limb_idx in range(self.num_query): + limb_features = x[:, limb_idx:limb_idx+1, :] # (batch_size, 1, hidden_dim) + normalized_limb = self.norm[limb_idx](limb_features) + x_normalized.append(normalized_limb) + x_rep = torch.cat(x_normalized, dim=1) # (batch_size, num_query, hidden_dim) + else: + # Case 2: (batch_size, num_query * seq_len, hidden_dim) + # 假设seq_len相同,重新reshape + total_tokens = x.size(1) + seq_len = total_tokens // self.num_query + x_reshaped = x.view(batch_size, self.num_query, seq_len, -1) # (batch_size, num_query, seq_len, hidden_dim) + + x_normalized = [] + for limb_idx in range(self.num_query): + limb_features = x_reshaped[:, limb_idx, :, :] # (batch_size, seq_len, hidden_dim) + normalized_limb = self.norm[limb_idx](limb_features) + x_normalized.append(normalized_limb) + x_normalized = torch.stack(x_normalized, dim=1) # (batch_size, num_query, seq_len, hidden_dim) + x_rep = x_normalized.view(batch_size, total_tokens, -1) # 恢复原始shape + + elif self.action_norm == "insnorm": + x_rep = self.norm(x.permute(0,2,1)) + x_rep = x_rep.permute(0,2,1) + else: + x_rep = self.norm(x) + + outputs = self.action_projection(self.dropout(x_rep)) + if self.use_contrastive_loss: + return outputs, x_rep + else: + return outputs + +class LatentRobotDecoder(nn.Module): + def __init__(self, num_blocks, + input_dim, + hidden_dim, + mlp_type = 'ffn', + proj_type= 'linear_relu', + # MoE相关参数 + num_experts=8, + top_k=2, + expert_capacity_factor=1.0, + expansion_ratio=4.0): # 添加扩展倍数参数 + super().__init__() + self.hidden_projection = Query2ActionAdapter( + input_dim=input_dim, + hidden_dim=hidden_dim, + proj_type=proj_type, + ) + if num_blocks == 0 : + self.mlps = nn.Identity() + else: + if mlp_type == 'ffn': + self.mlps = nn.Sequential( + *[RoboFFN(hidden_dim=hidden_dim, multi_query_norm_type='layernorm', num_query=2) for i in range(num_blocks)], + ) + elif mlp_type == 'moe': + self.mlps = nn.Sequential( + *[MoELayer( + hidden_dim=hidden_dim, + num_experts=num_experts, + top_k=top_k, + expert_capacity_factor=expert_capacity_factor, + expansion_ratio=expansion_ratio # 传递扩展倍数参数 + ) for i in range(num_blocks)], + ) + else: + self.mlps = nn.Sequential( + *[GatingMLP(hidden_dim=hidden_dim) for i in range(num_blocks)], + ) + + def forward(self, x ): + x = self.hidden_projection(x) + x = self.mlps(x) + return x + + +class QueryAttnActionHead(nn.Module): + """ + 用可学习 Query + Cross-Attention 从单一 embedding 解码完整动作序列。 + """ + def __init__( + self, + input_dim: int = 
4096, + hidden_dim: int = 4096, + action_dim: int = ACTION_DIM, + chunk_size: int = NUM_ACTIONS_CHUNK, + decoder_num_blocks:int=2, + mlp_type:str='ffn', + nhead: int = 8, + drop_ratio: float = 0.1, + **kwargs, + ): + super().__init__() + self.chunk_size = chunk_size + self.query_embed = nn.Parameter(torch.zeros(1, chunk_size, hidden_dim)) + # 把 backbone 的高维特征映射到 hidden_dim,注意力里用 + self.mem_proj = nn.Sequential( + nn.LayerNorm(input_dim), + nn.ReLU(), + nn.Linear(input_dim, hidden_dim) + ) + # Q×K/V 的跨注意力;因为 memory 只有 1 token,可以用较少 head + self.cross_attn = nn.MultiheadAttention(hidden_dim, nhead, batch_first=True) + self.ffn = nn.Sequential( + nn.LayerNorm(hidden_dim), + nn.Linear(hidden_dim,hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim,hidden_dim) + ) + # 一个很轻量的 FFN 产生动作 + self.action_predictor = nn.Sequential( + nn.LayerNorm(hidden_dim), + nn.Dropout(drop_ratio), + nn.Linear(hidden_dim, action_dim), + ) + + def predict_action(self, actions_hidden_states: torch.Tensor, **kwargs): + """ + args: + actions_hidden_states: (B, 1, input_dim) —— 单一聚合 embedding + return: + actions: (B, chunk_size, action_dim) + """ + B = actions_hidden_states.size(0) + # 1) memory 投射 + mem = self.mem_proj(actions_hidden_states) # (B, 1, hidden_dim) + # 2) 拿到 query,并复制到 batch + q = self.query_embed.repeat(B, 1, 1) # (B, chunk_size, hidden_dim) + # 3) Cross-Attention + attn_out, _ = self.cross_attn(q, mem, mem) # (B, chunk_size, hidden_dim) + + outputs = self.ffn(attn_out) + # 4) FFN -> action + actions = self.action_predictor(outputs+attn_out) # (B, chunk_size, action_dim) + return actions + + +class MHActionHead(nn.Module): + def __init__( + self, + input_dim=4096, + hidden_dim=4096, + action_dim=7, + decoder_num_blocks=2, + mlp_type = 'ffn', + # MoE相关参数 + num_experts=8, + top_k=2, + expert_capacity_factor=1.0, + expansion_ratio=4.0 # 添加扩展倍数参数 + ): + super().__init__() + self.action_dim = action_dim + self.horizon_dims = [ SHORT_NUM_ACTIONS_CHUNK, MID_NUM_ACTIONS_CHUNK, NUM_ACTIONS_CHUNK ] + self.latent_multi_horizon_planner = nn.ModuleList( + [ + LatentRobotDecoder(num_blocks = decoder_num_blocks, + input_dim = input_dim, + hidden_dim = hidden_dim, + mlp_type = mlp_type, + num_experts = num_experts, + top_k = top_k, + expert_capacity_factor = expert_capacity_factor, + expansion_ratio = expansion_ratio) for i in range(len(self.horizon_dims) + ) + ] + ) + self.action_decoding = nn.ModuleList( + [ + nn.Sequential( + RoboFFN(hidden_dim=hidden_dim, multi_query_norm_type='layernorm', num_query=2), + nn.LayerNorm(hidden_dim), + nn.Linear(hidden_dim, self.horizon_dims[i] * action_dim) + ) for i in range(len(self.horizon_dims)) + ] + ) + def predict_action(self, actions_hidden_states , num_action_chunk = 8): + # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence + # - shape: (batch_size, 1, hidden_dim) + # ground_truth_actions: ground-truth actions + # - shape: (batch_size, chunk_len, action_dim) + if self.training: + actions = [] # actions: list + for i,dim in enumerate(self.horizon_dims): + action_latents = self.latent_multi_horizon_planner[i](actions_hidden_states) + action = self.action_decoding[i](action_latents) + action = action.reshape(action.size(0), dim, -1) + actions.append(action) + else: + action_horizon_size = self.horizon_dims.index(num_action_chunk) + action_latents = self.latent_multi_horizon_planner[action_horizon_size](actions_hidden_states) + action = self.action_decoding[action_horizon_size](action_latents) + actions = action.reshape(action.size(0), 
self.horizon_dims[action_horizon_size], -1) # actions: tensor + return actions + +class SharedLatentMHActionHead(nn.Module): + def __init__( + self, + input_dim=4096, + hidden_dim=4096, + action_dim=7, + decoder_num_blocks=2, + mlp_type = 'ffn', + # MoE相关参数 + num_experts=8, + top_k=2, + expert_capacity_factor=1.0, + expansion_ratio=4.0 # 添加扩展倍数参数 + ): + super().__init__() + self.action_dim = action_dim + self.horizon_dims = [ SHORT_NUM_ACTIONS_CHUNK, MID_NUM_ACTIONS_CHUNK, NUM_ACTIONS_CHUNK ] + self.latent_multi_horizon_planner = LatentRobotDecoder(num_blocks = decoder_num_blocks, + input_dim = input_dim, + hidden_dim = hidden_dim, + mlp_type = mlp_type, + num_experts = num_experts, + top_k = top_k, + expert_capacity_factor = expert_capacity_factor, + expansion_ratio = expansion_ratio) # 传递扩展倍数参数 + + self.action_decoding = nn.ModuleList( + [ + nn.Sequential( + RoboFFN(hidden_dim=hidden_dim, multi_query_norm_type='layernorm', num_query=2), + RoboFFN(hidden_dim=hidden_dim, multi_query_norm_type='layernorm', num_query=2), + nn.LayerNorm(hidden_dim), + nn.Linear(hidden_dim, self.horizon_dims[i] * action_dim) + ) for i in range(len(self.horizon_dims)) + ] + ) + def predict_action(self, actions_hidden_states , num_action_chunk = 8): + # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence + # - shape: (batch_size, 1, hidden_dim) + # ground_truth_actions: ground-truth actions + # - shape: (batch_size, chunk_len, action_dim) + if self.training: + actions = [] # actions: list + action_latents = self.latent_multi_horizon_planner(actions_hidden_states) + for i,dim in enumerate(self.horizon_dims): + action = self.action_decoding[i](action_latents) + action = action.reshape(action.size(0), dim, -1) + actions.append(action) + else: + action_horizon_size = self.horizon_dims.index(num_action_chunk) + action_latents = self.latent_multi_horizon_planner(actions_hidden_states) + action = self.action_decoding[action_horizon_size](action_latents) + actions = action.reshape(action.size(0), self.horizon_dims[action_horizon_size], -1) # actions: tensor + return actions + + +class MultiScaleActionHead(nn.Module): + def __init__( + self, + input_dim=4096, + hidden_dim=4096, + action_dim=7, + decoder_num_blocks=2, + mlp_type = 'ffn', + # MoE相关参数 + num_experts=8, + top_k=2, + expert_capacity_factor=1.0, + expansion_ratio=4.0 # 添加扩展倍数参数 + ): + super().__init__() + self.action_dim = action_dim + self.horizon_dims = [ SHORT_NUM_ACTIONS_CHUNK, MID_NUM_ACTIONS_CHUNK, NUM_ACTIONS_CHUNK ] + self.multscaleheads = nn.ModuleList( + [ + RobotDecoder(num_blocks = decoder_num_blocks, + input_dim = input_dim, + hidden_dim = hidden_dim, + output_dims = self.horizon_dims[i] * action_dim, + mlp_type = mlp_type, + num_experts = num_experts, + top_k = top_k, + expert_capacity_factor = expert_capacity_factor, + expansion_ratio = expansion_ratio) for i in range(len(self.horizon_dims) + ) + ] + ) + def predict_action(self, actions_hidden_states , action_horizon_type = 0): + # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence + # - shape: (batch_size, 1, hidden_dim) + # ground_truth_actions: ground-truth actions + # - shape: (batch_size, chunk_len, action_dim) + if self.training: + actions = [] # actions: list + for i,dim in enumerate(self.horizon_dims): + action = self.multscaleheads[i](actions_hidden_states[:, i:i+1]) + action = action.reshape(action.size(0), dim, -1) + actions.append(action) + else: + action = 
self.multscaleheads[action_horizon_type](actions_hidden_states) + actions = actions.reshape(actions.size(0), self.horizon_dims[action_horizon_type], -1) # actions: tensor + return actions + + + +class TSActionHead(nn.Module): + def __init__( + self, + input_dim=4096, + hidden_dim=4096, + action_dim=7, + chunk_size=8, + decoder_num_blocks = 2, + proj_type='gelu_linear', + mlp_type = 'ffn', + ffn_type = 'gelu', + drop_ratio = 0.1, + without_action_projector=False, + action_norm="layernorm", + # MoE相关参数 + num_experts=6, + top_k=2, + expert_capacity_factor=1.0, + expansion_ratio=2.0, # 添加扩展倍数参数 + num_shared_experts = 1, + use_contrastive_loss=False, + multi_query_norm_type="layernorm", + num_query=1, + **kwargs + ): + super().__init__() + self.chunk_size = chunk_size + self.use_contrastive_loss=use_contrastive_loss + self.action_dim = 7 + self.head = RobotDecoder( num_blocks = decoder_num_blocks, + input_dim = input_dim, + hidden_dim = hidden_dim, + output_dims = NUM_ACTIONS_CHUNK*self.action_dim , + mlp_type = mlp_type, + proj_type = proj_type, + ffn_type = ffn_type, + drop_ratio = drop_ratio, + without_action_projector=without_action_projector, + action_norm = action_norm, + num_experts = num_experts, + top_k = top_k, + expert_capacity_factor = expert_capacity_factor, + expansion_ratio = expansion_ratio, + num_shared_experts = num_shared_experts, + use_contrastive_loss=use_contrastive_loss, + multi_query_norm_type=multi_query_norm_type, + num_query=num_query) # 传递扩展倍数参数 + + def predict_action(self, actions_hidden_states, num_action_chunk = 8): + # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence + # - shape: (batch_size, 1, hidden_dim) + # ground_truth_actions: ground-truth actions + # - shape: (batch_size, chunk_len, action_dim) + if self.use_contrastive_loss: + actions, action_rep = self.head(actions_hidden_states) + actions = actions.reshape(actions.size(0), NUM_ACTIONS_CHUNK, -1) + return actions, action_rep + else: + actions = self.head(actions_hidden_states) # (batch_size, 2, action_dim * NUM_ACTIONS_CHUNK) + # actions = rearrange(actions,"b l d -> b d l") + b,l,a = actions.size() + # actions = rearrange(actions,"b l (t d) -> b t (l d)", b =b, l=l, t= NUM_ACTIONS_CHUNK, d = self.action_dim) + actions = actions.reshape(actions.size(0), NUM_ACTIONS_CHUNK, -1) + return actions + + +class TActionHead(nn.Module): + def __init__( + self, + input_dim=4096, + hidden_dim=4096, + action_dim=7, + chunk_size=8, + decoder_num_blocks = 2, + proj_type='gelu_linear', + mlp_type = 'ffn', + ffn_type = 'gelu', + drop_ratio = 0.1, + without_action_projector=False, + action_norm="layernorm", + # MoE相关参数 + num_experts=6, + top_k=2, + expert_capacity_factor=1.0, + expansion_ratio=2.0, # 添加扩展倍数参数 + num_shared_experts = 1, + use_contrastive_loss=False, + multi_query_norm_type="layernorm", + num_query=1, + **kwargs + ): + super().__init__() + self.chunk_size = chunk_size + self.use_contrastive_loss=use_contrastive_loss + self.head = RobotDecoder( num_blocks = decoder_num_blocks, + input_dim = input_dim, + hidden_dim = hidden_dim, + output_dims = NUM_ACTIONS_CHUNK, + mlp_type = mlp_type, + proj_type = proj_type, + ffn_type = ffn_type, + drop_ratio = drop_ratio, + without_action_projector=without_action_projector, + action_norm = action_norm, + num_experts = num_experts, + top_k = top_k, + expert_capacity_factor = expert_capacity_factor, + expansion_ratio = expansion_ratio, + num_shared_experts = num_shared_experts, + use_contrastive_loss=use_contrastive_loss, + 
multi_query_norm_type=multi_query_norm_type, + num_query=num_query) # 传递扩展倍数参数 + + def predict_action(self, actions_hidden_states, num_action_chunk = 8): + # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence + # - shape: (batch_size, 1, hidden_dim) + # ground_truth_actions: ground-truth actions + # - shape: (batch_size, chunk_len, action_dim) + if self.use_contrastive_loss: + actions, action_rep = self.head(actions_hidden_states) + actions = actions.reshape(actions.size(0), NUM_ACTIONS_CHUNK, -1) + return actions, action_rep + else: + actions = self.head(actions_hidden_states) # (batch_size, Action_dim , NUM_ACTIONS_CHUNK) + # actions = rearrange(actions,"b d l -> b l d") + actions = actions.reshape(actions.size(0), NUM_ACTIONS_CHUNK, -1) + return actions + + +class SActionHead(nn.Module): + def __init__( + self, + input_dim=4096, + hidden_dim=4096, + action_dim=7, + chunk_size=8, + decoder_num_blocks = 2, + proj_type='gelu_linear', + mlp_type = 'ffn', + ffn_type = 'gelu', + drop_ratio = 0.1, + without_action_projector=False, + action_norm="layernorm", + # MoE相关参数 + num_experts=6, + top_k=2, + expert_capacity_factor=1.0, + expansion_ratio=2.0, # 添加扩展倍数参数 + num_shared_experts = 1, + use_contrastive_loss=False, + multi_query_norm_type="layernorm", + num_query=1, + **kwargs + ): + super().__init__() + self.chunk_size = chunk_size + self.use_contrastive_loss=use_contrastive_loss + self.head = RobotDecoder( num_blocks = decoder_num_blocks, + input_dim = input_dim, + hidden_dim = hidden_dim, + output_dims = action_dim, + mlp_type = mlp_type, + proj_type = proj_type, + ffn_type = ffn_type, + drop_ratio = drop_ratio, + without_action_projector=without_action_projector, + action_norm = action_norm, + num_experts = num_experts, + top_k = top_k, + expert_capacity_factor = expert_capacity_factor, + expansion_ratio = expansion_ratio, + num_shared_experts = num_shared_experts, + use_contrastive_loss=use_contrastive_loss, + multi_query_norm_type=multi_query_norm_type, + num_query=num_query) # 传递扩展倍数参数 + + def predict_action(self, actions_hidden_states, num_action_chunk = 8): + # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence + # - shape: (batch_size, 1, hidden_dim) + # ground_truth_actions: ground-truth actions + # - shape: (batch_size, chunk_len, action_dim) + if self.use_contrastive_loss: + actions, action_rep = self.head(actions_hidden_states) + return actions, action_rep + else: + actions = self.head(actions_hidden_states) # (batch_size, NUM_ACTIONS_CHUNK , action_dim ) + # actions = actions.reshape(actions.size(0), NUM_ACTIONS_CHUNK, -1) + return actions + + +class MultiGranularityTSActionHead(nn.Module): + """ + Multi-granularity action head based on TSActionHead structure. + Fine-grained actions are extracted based on coarse-grained actions. 
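+    # Intended coarse-to-fine data flow (summarising predict_action below):
+    #   1. coarse_hidden_projection + coarse_head map the pooled hidden state to a coarse
+    #      action chunk, reshaped to (batch, NUM_ACTIONS_CHUNK, action_dim).
+    #   2. Multi-scale Conv1d branches (kernel sizes 3/5/7) over the coarse feature are fused
+    #      by a 1x1 conv, and out_linear turns the fused feature into a residual delta.
+    #   3. fine_actions = coarse_actions + delta, so the fine head only refines the coarse plan.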
+ """ + def __init__( + self, + input_dim=4096, + hidden_dim=4096, + action_dim=7, + chunk_size=8, + decoder_num_blocks=2, + mlp_type='ffn' + ): + super().__init__() + self.chunk_size = chunk_size + self.action_dim = action_dim + + self.coarse_hidden_projection = nn.Sequential( + nn.LayerNorm(input_dim), + nn.ReLU(), + nn.Linear(input_dim, hidden_dim), + *[RoboFFN(hidden_dim=hidden_dim, multi_query_norm_type='layernorm', num_query=2) for i in range(decoder_num_blocks)] + ) + + # 粗粒度动作头 (类似原始TSActionHead) + self.coarse_head = nn.Sequential( + nn.LayerNorm(hidden_dim), + nn.Dropout(0.1), + nn.Linear(hidden_dim, chunk_size*action_dim) + ) + + # 多尺度卷积层直接在粗粒度actions上捕捉细粒度特征 + self.multi_scale_convs = nn.ModuleList([ + nn.Sequential( + nn.Conv1d(hidden_dim, hidden_dim, kernel_size=k, padding=k//2), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(inplace=True) + ) + for k in [3, 5, 7] + ]) + + # 融合层:Conv1×1 + BN(无激活,保持线性,适合回归) + self.feature_fusion = nn.Sequential( + nn.Conv1d(hidden_dim * len(self.multi_scale_convs), hidden_dim, kernel_size=1), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(inplace=True), + ) + + # 最终线性层:预测 residual(Δ),随后与 coarse 动作相加得到 fine 动作 + self.out_linear = nn.Sequential( + nn.LayerNorm(hidden_dim), + nn.Dropout(0.1), + nn.Linear(hidden_dim, chunk_size*action_dim) + ) + + + def predict_action(self, actions_hidden_states, num_action_chunk=8): + """ + 预测粗粒度和细粒度动作 + + Args: + actions_hidden_states: (batch_size, 1, input_dim) + + Returns: + dict: { + 'coarse_actions': (batch_size, chunk_size, action_dim) + 'fine_actions': (batch_size, chunk_size, action_dim) + } + """ + batch_size = actions_hidden_states.shape[0] + + # 1. 粗粒度动作预测 (使用原始TSActionHead结构) + coarse_features = self.coarse_hidden_projection(actions_hidden_states) + coarse_actions = self.coarse_head(coarse_features) + coarse_actions = coarse_actions.reshape(batch_size, NUM_ACTIONS_CHUNK, -1) + + # 2. 直接在粗粒度actions上进行多尺度卷积 + # 转换为卷积格式: (batch_size, hidden_dim, chunk_size) + conv_input = coarse_features.permute(0, 2, 1) + + # 3. 多尺度卷积处理粗粒度actions + multi_scale_features = [] + for conv in self.multi_scale_convs: + multi_scale_features.append(conv(conv_input)) + + # 4. 
融合多尺度特征 + # 拼接所有尺度的特征: (B, action_dim * num_scales, chunk_size) + fused_features = torch.cat(multi_scale_features, dim=1) + fine_actions_conv = self.feature_fusion(fused_features) # (B, action_dim, chunk_size) + + # 转换回序列格式: (B, chunk_size, action_dim) + fine_actions = fine_actions_conv.permute(0, 2, 1) + + # 计算 residual,再与 coarse 动作相加形成细粒度动作 + fine_actions_delta = self.out_linear(fine_actions) + fine_actions = coarse_actions + fine_actions_delta + + return { + 'coarse_actions': coarse_actions, + 'fine_actions': fine_actions + } + + + +class SimTSActionHead(nn.Module): + def __init__( + self, + input_dim=4096, + hidden_dim=4096, + action_dim=7, + ): + super().__init__() + self.action_dim = action_dim + self.memory_ffn = nn.Sequential( + nn.Linear(input_dim,hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim,hidden_dim) + ) + self.action_projection = nn.Sequential( + nn.Dropout(0.5), + nn.Linear(hidden_dim,NUM_ACTIONS_CHUNK) + ) + def predict_action(self, actions_hidden_states): + # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence + # - shape: (batch_size, action_dim, hidden_dim) + # ground_truth_actions: ground-truth actions + # - shape: (batch_size, chunk_len, action_dim) + actions = self.action_projection(self.memory_ffn(actions_hidden_states)) + return actions.permute(0, 2, 1) # (batch_size, chunk_len, action_dim) + + + +class NoisePredictionModel(nn.Module): + """ + Diffusion noise prediction model that takes an observation embedding (which fuses the + noisy action, diffusion timestep, and image-language observation embeddings) and + outputs a noise prediction. + """ + + def __init__( + self, + transformer_hidden_dim, # Transformer hidden embedding size + hidden_dim, # MLP hidden size + action_dim=7, # action dimensionality + ): + super().__init__() + self.mlp_resnet = MLPResNet( + num_blocks=2, + input_dim=transformer_hidden_dim, + hidden_dim=hidden_dim, + output_dim=action_dim, + ) + + def forward( + self, + obs, + ): + # obs: observation embeddings to condition the generation on + # - shape: (batch_size, chunk_len, rearranged_hidden_dim=action_dim*hidden_dim) + # + # output: predicted noise + # - shape: (batch_size, action_dim) + output = self.mlp_resnet(obs) + return output + + +class DiffusionActionHead(nn.Module): + """ + Simple MLP-based action head that generates continuous actions via conditional denoising diffusion process. + + Loosely inspired by: https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/transformer_for_diffusion.py + """ + + def __init__( + self, + input_dim=4096, + hidden_dim=4096, + action_dim=7, + num_diffusion_steps=100, + ): + super().__init__() + self.action_dim = action_dim + self.noise_predictor = NoisePredictionModel( + transformer_hidden_dim=hidden_dim*ACTION_DIM, hidden_dim=hidden_dim, action_dim=action_dim + ) + self.noise_scheduler = DDIMScheduler(num_train_timesteps=num_diffusion_steps, beta_schedule="squaredcos_cap_v2") + self.num_diffusion_steps = num_diffusion_steps + self.time_encoder = SinusoidalPositionalEncoding(dim=hidden_dim) + + def sample_noisy_actions(self, ground_truth_actions): + """ + Samples noise and applies noise to ground-truth actions to produce noisy actions, which are + used as input in the noise prediction network. Returns noise, noisy actions, and the + corresponding diffusion timestep embeddings. 
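+        # Training-step sketch (the VLM forward pass that fuses the noisy actions and timestep
+        # embedding into actions_hidden_states happens outside this module):
+        #   out = head.sample_noisy_actions(gt_actions)
+        #   ... run the VLM on out["noisy_actions"] and out["diffusion_timestep_embeddings"] ...
+        #   noise_pred = head.predict_noise(actions_hidden_states)
+        #   loss = F.mse_loss(noise_pred, out["noise"])   # assumed epsilon-prediction objective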
+ """ + # ground_truth_actions: ground-truth actions + # - shape: (batch_size, chunk_len, action_dim) + batch_size = ground_truth_actions.shape[0] + device = ground_truth_actions.device + # Sample random noise with shape equal to actions, used for closed-form forward diffusion. + noise = torch.randn(size=(batch_size, NUM_ACTIONS_CHUNK, ACTION_DIM), device=device, dtype=ground_truth_actions.dtype) # (B, chunk_len, action_dim) + # Sample random diffusion timesteps (one for each action in batch). + timesteps = torch.randint( + low=0, high=self.noise_scheduler.config.num_train_timesteps, size=(batch_size,), device=device + ) + # Add noise to clean actions according to the magnitude at each diffusion timestep via + # closed-form forward diffusion. + noisy_actions = self.noise_scheduler.add_noise(ground_truth_actions, noise, timesteps) # (B, chunk_len, action_dim) + + # Get diffusion timestep embeddings as well + diffusion_timestep_embeddings = self.time_encoder(timesteps).to(noisy_actions.dtype).to(noisy_actions.device) # (B, llm_dim) + diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim) + + return_dict = dict( + noise=noise, + noisy_actions=noisy_actions, + diffusion_timestep_embeddings=diffusion_timestep_embeddings, + ) + + return return_dict + + def predict_noise(self, actions_hidden_states): + """ + Given a batch of last hidden Transformer layer embeddings (which fuse the vision-language observation embeddings, + noisy action embeddings, and diffusion timestep embedding), predicts the noise applied to the actions. + """ + # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence + # - shape: (batch_size, chunk_len * action_dim, hidden_dim) + batch_size = actions_hidden_states.shape[0] + device = actions_hidden_states.device + rearranged_actions_hidden_states = actions_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK, -1) # (batch_size, chunk_len, action_dim * hidden_dim) + # Get diffusion model's noise prediction. 
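+        # At inference time this is typically called once per denoising step: starting from pure
+        # Gaussian noise, the VLM re-encodes the current noisy actions, and the prediction below
+        # is passed to self.noise_scheduler.step(noise_pred, t, noisy_actions) to take one DDIM
+        # step; the sampling loop itself lives outside this module (sketch of intended usage).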
+        noise_pred = self.noise_predictor(rearranged_actions_hidden_states)
+        return noise_pred
+
+
+class TemporalTransformerActionHead(nn.Module):
+    """Transformer-encoder-based action sequence prediction head.
+
+    This module first projects each timestep's hidden state (the concatenation across action_dim)
+    to a lower-dimensional temporal embedding, then models the temporal dimension with several
+    self-attention layers, and finally maps back to the action space.
+
+    Compared to a pure MLP, temporal dependencies are modeled explicitly, which is an advantage on
+    long sequences and for cross-task generalization.
+    """
+
+    def __init__(
+        self,
+        input_dim: int = 4096,
+        hidden_dim: int = 256,
+        action_dim: int = ACTION_DIM,
+        num_layers: int = 4,
+        nhead: int = 8,
+        dim_feedforward: int = 512,
+        dropout: float = 0.1,
+        predicted_dropout: float = 0.4,
+    ) -> None:
+        """
+        Args:
+            input_dim: Hidden size of the Transformer backbone (the last dimension of the incoming actions_hidden_states).
+            hidden_dim: Internal embedding size (d_model) of the temporal Transformer.
+            action_dim: Robot action dimensionality.
+            num_layers: Number of TransformerEncoderLayer layers.
+            nhead: Number of attention heads.
+            dim_feedforward: Feedforward size inside each TransformerEncoderLayer.
+            dropout: Dropout probability inside the encoder.
+            predicted_dropout: Dropout applied before the output projection.
+        """
+        super().__init__()
+
+        # The number of input tokens equals ACTION_DIM.
+        self.action_dim = action_dim
+
+        # Project each action token's high-dimensional representation down to d_model to reduce compute.
+        self.input_projection = nn.Sequential(
+            nn.Linear(input_dim, input_dim),
+            nn.ReLU(),
+            nn.Linear(input_dim, hidden_dim)
+        )
+        # Learnable positional encoding for the ACTION_DIM tokens (their order is fixed, so length = ACTION_DIM).
+        self.pos_embedding = nn.Parameter(
+            torch.zeros(1, ACTION_DIM, hidden_dim), requires_grad=True
+        )
+
+        # Transformer encoder
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=hidden_dim,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            dropout=dropout,
+            batch_first=True,
+            activation="gelu",
+            norm_first=True,
+        )
+        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+
+        self.dropout = nn.Dropout(predicted_dropout)
+
+        # Output projection: map each action-dim token to a length-NUM_ACTIONS_CHUNK time series.
+        self.output_projection = nn.Linear(hidden_dim, NUM_ACTIONS_CHUNK)
+
+        # Initialization
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        nn.init.trunc_normal_(self.pos_embedding, std=0.02)
+        # The default initialization is fine for the Linear layers.
+
+    def predict_action(self, actions_hidden_states: torch.Tensor) -> torch.Tensor:
+        """Predict an action sequence.
+
+        Args:
+            actions_hidden_states: Last-layer hidden states of the Transformer for the action tokens,
+                of shape (batch_size, ACTION_DIM, input_dim).
+
+        Returns:
+            Predicted action sequence of shape (batch_size, NUM_ACTIONS_CHUNK, action_dim).
+        """
+        B, A, D = actions_hidden_states.shape  # A == ACTION_DIM
+        assert A == ACTION_DIM, (
+            f"The second dimension of actions_hidden_states should equal ACTION_DIM, but got {A} != {ACTION_DIM}"
+        )
+
+        # Linearly project each action token to a lower dimension.
+        x = self.input_projection(actions_hidden_states)  # (B, ACTION_DIM, hidden_dim)
+
+        # Add the learnable positional encoding.
+        x = x + self.pos_embedding[:, :ACTION_DIM, :]
+
+        # Transformer encoder (batch_first=True)
+        x = self.transformer_encoder(x)  # (B, ACTION_DIM, hidden_dim)
+
+        # Map each hidden representation to a time series of length NUM_ACTIONS_CHUNK.
+        actions = self.output_projection(self.dropout(x))  # (B, ACTION_DIM, NUM_ACTIONS_CHUNK)
+
+        # Rearrange to (B, NUM_ACTIONS_CHUNK, ACTION_DIM).
+        actions = actions.permute(0, 2, 1)
+        return actions
+
+
+class TemporalConvActionHead(nn.Module):
+    """Action sequence prediction head based on a 1D temporal convolutional network (TCN).
+
+    Stacked dilated convolutions capture long-range dependencies at a lower compute cost than a
+    Transformer, and tend to generalize better and train more stably when data is limited.
+    """
+
+    def __init__(
+        self,
+        input_dim: int = 4096,
+        action_dim: int = ACTION_DIM,
+        hidden_dim: int = 512,
+        num_layers: int = 4,
+        kernel_size: int = 3,
+        dropout: float = 0.1,
+        predicted_dropout: float = 0.4,
+    ) -> None:
+        super().__init__()
+        self.action_dim = action_dim
+
+        # Convolution channels = input_dim, sequence length = ACTION_DIM.
+        layers = []
+        in_channels = input_dim
+        dilation = 1
+        for _ in range(num_layers):
+            layers.append(
+                nn.Sequential(
+                    nn.Conv1d(
+                        in_channels,
+                        hidden_dim,
+                        kernel_size,
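+                        # With dilation d and odd kernel size k, padding = (k - 1) * d // 2 keeps the
+                        # output length equal to the input length ('same' padding), so every layer
+                        # still covers all ACTION_DIM positions.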
+                        padding=(kernel_size - 1) * dilation // 2,
+                        dilation=dilation,
+                    ),
+                    nn.BatchNorm1d(hidden_dim),
+                    nn.ReLU(),
+                    nn.Dropout(dropout),
+                )
+            )
+            in_channels = hidden_dim
+            dilation *= 2
+        self.tcn = nn.Sequential(*layers)
+        self.dropout = nn.Dropout(predicted_dropout)
+        # Final 1x1 convolution maps hidden_dim -> NUM_ACTIONS_CHUNK, i.e. the temporal sequence length.
+        self.fc_out = nn.Conv1d(hidden_dim, NUM_ACTIONS_CHUNK, kernel_size=1)
+
+    def predict_action(self, actions_hidden_states: torch.Tensor) -> torch.Tensor:
+        """Predict an action sequence.
+
+        Args:
+            actions_hidden_states: shape (B, ACTION_DIM, input_dim)
+
+        Returns:
+            shape (B, NUM_ACTIONS_CHUNK, action_dim)
+        """
+        B, A, D = actions_hidden_states.shape
+        assert A == ACTION_DIM, (
+            f"The second dimension of actions_hidden_states should equal ACTION_DIM, but got {A} != {ACTION_DIM}"
+        )
+
+        # Rearrange to (B, input_dim, ACTION_DIM) for 1D convolution.
+        x = actions_hidden_states.permute(0, 2, 1)  # (B, D, A)
+        x = self.tcn(x)  # (B, hidden_dim, A)
+
+        # Generate the time series: (B, NUM_ACTIONS_CHUNK, ACTION_DIM)
+        actions = self.fc_out(self.dropout(x))  # (B, NUM_ACTIONS_CHUNK, A)
+
+        # Output shape: (B, NUM_ACTIONS_CHUNK, ACTION_DIM)
+        return actions
+
+
+class moving_avg(nn.Module):
+    """
+    Moving average block to highlight the trend of time series
+    """
+    def __init__(self, kernel_size, stride):
+        super(moving_avg, self).__init__()
+        self.kernel_size = kernel_size
+        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
+
+    def forward(self, x):
+        # padding on both ends of the time series
+        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
+        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
+        x = torch.cat([front, x, end], dim=1)
+        x = self.avg(x.permute(0, 2, 1))
+        x = x.permute(0, 2, 1)
+        return x
+
+
+class series_decomp(nn.Module):
+    """
+    Series decomposition block
+    """
+    def __init__(self, kernel_size):
+        super(series_decomp, self).__init__()
+        self.moving_avg = moving_avg(kernel_size, stride=1)
+
+    def forward(self, x):
+        moving_mean = self.moving_avg(x)
+        res = x - moving_mean
+        return res, moving_mean
+
+
+class DLinear(nn.Module):
+    """
+    DLinear
+    """
+    def __init__(self, individual=False, enc_in=7, kernel_size=5):
+        super(DLinear, self).__init__()
+        self.seq_len = NUM_ACTIONS_CHUNK
+        self.pred_len = NUM_ACTIONS_CHUNK
+
+        # Decomposition kernel size
+        kernel_size = kernel_size
+        self.decomposition = series_decomp(kernel_size)
+        self.individual = individual
+        self.channels = enc_in
+
+        if self.individual:
+            self.Linear_Seasonal = nn.ModuleList()
+            self.Linear_Trend = nn.ModuleList()
+            self.Linear_Decoder = nn.ModuleList()
+            for i in range(self.channels):
+                self.Linear_Seasonal.append(nn.Linear(self.seq_len, self.pred_len))
+                self.Linear_Seasonal[i].weight = nn.Parameter((1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
+                self.Linear_Trend.append(nn.Linear(self.seq_len, self.pred_len))
+                self.Linear_Trend[i].weight = nn.Parameter((1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
+                self.Linear_Decoder.append(nn.Linear(self.seq_len, self.pred_len))
+        else:
+            self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
+            self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)
+            self.Linear_Decoder = nn.Linear(self.seq_len, self.pred_len)
+            self.Linear_Seasonal.weight = nn.Parameter((1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
+            self.Linear_Trend.weight = nn.Parameter((1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
+
+    def forward(self, x):
+        # x: [Batch, Input length, Channel]
+        seasonal_init, trend_init = self.decomposition(x)
+        seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)
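+        # seasonal_init / trend_init are now [Batch, Channel, Input length]; each channel's
+        # length-seq_len series is mapped to a length-pred_len series by a per-channel (or shared)
+        # linear layer below, and the two components are summed back together.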
+        if self.individual:
+            seasonal_output = torch.zeros([seasonal_init.size(0), seasonal_init.size(1), self.pred_len], dtype=seasonal_init.dtype).to(seasonal_init.device)
+            trend_output = torch.zeros([trend_init.size(0), trend_init.size(1), self.pred_len], dtype=trend_init.dtype).to(trend_init.device)
+            for i in range(self.channels):
+                seasonal_output[:, i, :] = self.Linear_Seasonal[i](seasonal_init[:, i, :])
+                trend_output[:, i, :] = self.Linear_Trend[i](trend_init[:, i, :])
+        else:
+            seasonal_output = self.Linear_Seasonal(seasonal_init)
+            trend_output = self.Linear_Trend(trend_init)
+
+        x = seasonal_output + trend_output
+        return x.permute(0, 2, 1)  # to [Batch, Output length, Channel]
+
+
+class L1DlinearActionHead(nn.Module):
+    """DLinear-based action head for continuous action prediction."""
+    def __init__(
+        self,
+        input_dim=4096,
+        hidden_dim=512,
+        kernel_size=5,
+        individual=True,
+    ):
+        super().__init__()
+        self.input_dim = input_dim
+
+        # Reduce the high-dimensional features of each timestep so they can be fed to DLinear.
+        self.action_enc = nn.Sequential(
+            nn.Linear(input_dim, input_dim),
+            nn.LayerNorm(input_dim),
+            nn.GELU(),
+            nn.Linear(input_dim, hidden_dim),
+        )
+
+        # Temporal modeling
+        self.model = DLinear(individual=individual, enc_in=ACTION_DIM, kernel_size=kernel_size)
+
+    def predict_action(self, actions_hidden_states):
+        # actions_hidden_states: (B, ACTION_DIM, hidden_dim)
+        x = self.action_enc(actions_hidden_states)  # (B, T, ACTION_DIM)
+
+        # Temporal modeling
+        x = self.model(x)  # (B, T, ACTION_DIM)
+
+        return x  # (B, NUM_ACTIONS_CHUNK, ACTION_DIM)
+
+
+class DeepSeekV3MoEActionHead(nn.Module):
+    """Action prediction head based on the DeepSeek-V3 MoE architecture.
+
+    Features:
+    1. Shared experts + routed experts (optional)
+    2. Adaptive bias correction (no auxiliary loss required)
+    3. Sigmoid-activated router
+    4. Efficient expert-parallel computation
+    5. GELU-activated FFN expert networks
+    """
+    def __init__(
+        self,
+        input_dim: int = 4096,
+        hidden_dim: int = 1024,
+        action_dim: int = ACTION_DIM,
+        num_routed_experts: int = 16,  # a moderate number of experts
+        num_shared_experts: int = 1,
+        top_k: int = 2,  # each token activates 2 routed experts
+        num_moe_layers: int = 2,
+        dropout: float = 0.1,
+        bias_update_speed: float = 0.01,
+        enable_load_balancing: bool = True,
+        enable_shared_expert: bool = False,
+        expansion_ratio: float = 4.0,  # expansion ratio of the expert FFNs
+    ):
+        super().__init__()
+        self.action_dim = action_dim
+        self.num_moe_layers = num_moe_layers
+        self.enable_load_balancing = enable_load_balancing
+
+        # Input projection: map the action token embeddings to the MoE hidden dimension.
+        self.input_projection = nn.Sequential(
+            nn.LayerNorm(input_dim),
+            nn.Linear(input_dim, hidden_dim),
+            nn.GELU()
+        )
+
+        # Stack of MoE layers
+        self.moe_layers = nn.ModuleList([
+            MoELayer(
+                hidden_dim=hidden_dim,
+                num_experts=num_routed_experts,
+                top_k=top_k,
+                dropout=dropout,
+                bias_update_speed=bias_update_speed,
+                enable_shared_expert=enable_shared_expert,
+                num_shared_experts=num_shared_experts,
+                expansion_ratio=expansion_ratio  # pass the expansion ratio through
+            )
+            for _ in range(num_moe_layers)
+        ])
+
+        # Output projection
+        self.output_projection = nn.Sequential(
+            nn.LayerNorm(hidden_dim),
+            nn.Identity() if dropout == 0.0 else nn.Dropout(dropout),
+            nn.Linear(hidden_dim, NUM_ACTIONS_CHUNK * action_dim)
+        )
+
+    def predict_action(self, actions_hidden_states: torch.Tensor) -> torch.Tensor:
+        """Predict an action sequence.
+
+        Args:
+            actions_hidden_states: Last-layer hidden states of the Transformer for the action tokens,
+                of shape (batch_size, ACTION_DIM, input_dim) or (batch_size, 1, input_dim).
+
+        Returns:
+            Predicted actions of shape (batch_size, NUM_ACTIONS_CHUNK, action_dim).
+        """
+        B = actions_hidden_states.size(0)
+
+        # Handle the different input shapes.
+        if actions_hidden_states.size(1) == ACTION_DIM:
+            # Shape: (B, ACTION_DIM, input_dim) -> (B, ACTION_DIM, hidden_dim)
+            x = self.input_projection(actions_hidden_states)
+        else:
+            # Shape: (B, 1, input_dim) -> (B, 1, hidden_dim)
+            x = self.input_projection(actions_hidden_states)
+
+        # Pass through the MoE layers.
+        for moe_layer in self.moe_layers:
+            x = moe_layer(x)
+
+        # Output projection
+        if x.size(1) == 1:
+            # If the input is a single token, it produces the whole action sequence.
+            actions = self.output_projection(x.squeeze(1))  # (B, NUM_ACTIONS_CHUNK * action_dim)
+            actions = actions.reshape(B, NUM_ACTIONS_CHUNK, self.action_dim)
+        else:
+            # If the input has multiple tokens, each token produces one action sequence.
+            actions = self.output_projection(x)  # (B, ACTION_DIM, NUM_ACTIONS_CHUNK * action_dim)
+            # Rearrange into time-series format.
+            actions = actions.reshape(B, ACTION_DIM, NUM_ACTIONS_CHUNK, self.action_dim)
+            actions = actions.permute(0, 2, 1, 3)  # (B, NUM_ACTIONS_CHUNK, ACTION_DIM, action_dim)
+            # Aggregate over the token dimension (here a simple mean; taking a single token or a weighted combination would also work).
+            actions = actions.mean(dim=2)  # (B, NUM_ACTIONS_CHUNK, action_dim)
+
+        return actions
+
+    def get_load_balancing_loss(self):
+        """Return the load-balancing loss averaged over all MoE layers."""
+        if not self.enable_load_balancing:
+            return torch.tensor(0.0)
+
+        total_loss = torch.tensor(0.0)
+        for moe_layer in self.moe_layers:
+            total_loss += moe_layer.get_load_balancing_loss()
+
+        return total_loss / len(self.moe_layers)
+
+    def get_expert_usage_stats(self):
+        """Return per-layer expert usage statistics (for monitoring and debugging)."""
+        stats = {}
+        for i, moe_layer in enumerate(self.moe_layers):
+            layer_stats = moe_layer.get_routing_stats()
+            stats[f'layer_{i}'] = layer_stats
+
+        return stats
+
+
+# adaLN-Zero-related modules
+class AdaLNZeroConditioner(nn.Module):
+    """
+    Text conditioner that maps text features to the adaLN-Zero modulation parameters.
+    """
+    def __init__(self, hidden_dim: int, text_dim: int):
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.text_dim = text_dim
+
+        # Text feature encoder
+        self.text_encoder = nn.Sequential(
+            nn.Linear(text_dim, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, hidden_dim * 3)  # outputs the scale, shift, and gate parameters
+        )
+
+        # Initialization: the gate parameters start at 0 ("zero" initialization).
+        with torch.no_grad():
+            # Zero out the weights and biases of the gate portion.
+            self.text_encoder[-1].weight[-hidden_dim:].zero_()
+            self.text_encoder[-1].bias[-hidden_dim:].zero_()
+
+    def forward(self, text_hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            text_hidden_states: (B, text_seq_len, text_dim) hidden states of the text portion; text_seq_len may vary.
+        Returns:
+            condition_params: (B, hidden_dim * 3) modulation parameters [scale, shift, gate]
+        """
+        # Mean-pool the text hidden states directly.
+        text_features = text_hidden_states.mean(dim=1)  # (B, text_dim)
+        # Produce the modulation parameters.
+        condition_params = self.text_encoder(text_features)  # (B, hidden_dim * 3)
+        return condition_params
+
+
+class AdaLNZeroBlock(nn.Module):
+    """
+    FFN block modulated with adaLN-Zero (FFN only, no attention).
+    """
+    def __init__(self,
+                 hidden_dim: int,
+                 text_dim: int,
+                 ffn_type: str = 'relu',
+                 ratio: float = 2.0,
+                 action_norm: str = "layernorm",
+                 dropout: float = 0.0,
+                 ):
+        super().__init__()
+        self.hidden_dim = hidden_dim
+
+        # Standard FFN components
+        self.norm = nn.LayerNorm(hidden_dim, elementwise_affine=False)  # LayerNorm without affine parameters
+        self.ffn = RoboFFN(hidden_dim, ratio, ffn_type, dropout, multi_query_norm_type='layernorm', num_query=2)
+
+        # adaLN-Zero conditioner
+        self.conditioner = AdaLNZeroConditioner(hidden_dim, text_dim)
+
+    def forward(self, x: torch.Tensor, text_condition: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x: (B, seq_len, hidden_dim) input features
+            text_condition: (B, text_seq_len, text_dim) text condition
+
+        Returns:
+            output: (B, seq_len, hidden_dim) output features
+        """
+        # Get the modulation parameters.
+        condition_params = self.conditioner(text_condition)  # (B, hidden_dim * 3)
+
+        # Split the modulation parameters into scale, shift, and gate.
+        scale, shift, gate = condition_params.chunk(3, dim=-1)  # each of shape (B, hidden_dim)
+
+        # Expand dimensions to match the input.
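+        # scale, shift, and gate are broadcast over seq_len below; together they implement
+        #     output = x + gate * FFN(LayerNorm(x) * (1 + scale) + shift)
+        # Because the gate branch of AdaLNZeroConditioner is zero-initialized, each block starts out
+        # as an identity mapping and the text conditioning is blended in gradually during training.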
+        scale = scale.unsqueeze(1)  # (B, 1, hidden_dim)
+        shift = shift.unsqueeze(1)  # (B, 1, hidden_dim)
+        gate = gate.unsqueeze(1)  # (B, 1, hidden_dim)
+
+        # Apply adaLN-Zero to the FFN:
+        # 1. Normalize (no affine transform).
+        normed_x = self.norm(x)  # (B, seq_len, hidden_dim)
+
+        # 2. Apply the conditioned scale and shift.
+        conditioned_x = normed_x * (1 + scale) + shift  # (B, seq_len, hidden_dim)
+
+        # 3. Pass through the FFN.
+        ffn_output = self.ffn(conditioned_x)  # (B, seq_len, hidden_dim)
+
+        # 4. Apply the gate and add the residual connection.
+        output = x + gate * ffn_output  # (B, seq_len, hidden_dim)
+
+        return output
+
+
+class AdaLNZeroRobotDecoder(nn.Module):
+    """
+    Robot action decoder with adaLN-Zero conditioning support.
+    """
+    def __init__(self,
+                 num_blocks: int,
+                 input_dim: int,
+                 hidden_dim: int,
+                 text_dim: int,  # new: text feature dimensionality
+                 output_dims: int,
+                 mlp_type: str = 'adaln_zero',
+                 ffn_type: str = 'relu',
+                 proj_type: str = 'linear_relu',
+                 drop_ratio: float = 0.1,
+                 without_action_projector: bool = False,
+                 expansion_ratio: float = 2.0,
+                 action_norm: str = "layernorm"):
+        super().__init__()
+
+        self.num_blocks = num_blocks
+        self.text_dim = text_dim
+
+        # Input projection
+        if without_action_projector:
+            self.hidden_projection = nn.Identity()
+        else:
+            self.hidden_projection = Query2ActionAdapter(
+                input_dim=input_dim,
+                hidden_dim=hidden_dim,
+                proj_type=proj_type,
+            )
+
+        # Main processing layers
+        if num_blocks == 0:
+            self.mlps = nn.Identity()
+        elif mlp_type == 'adaln_zero':
+            # Blocks modulated with adaLN-Zero
+            self.mlps = nn.ModuleList([
+                AdaLNZeroBlock(
+                    hidden_dim=hidden_dim,
+                    text_dim=text_dim,
+                    ffn_type=ffn_type,
+                    ratio=expansion_ratio,
+                    action_norm=action_norm
+                ) for _ in range(num_blocks)
+            ])
+        else:
+            # Keep the original implementations as a fallback.
+            if mlp_type == 'ffn':
+                self.mlps = nn.Sequential(
+                    *[RoboFFN(hidden_dim=hidden_dim, ffn_type=ffn_type, ratio=expansion_ratio, multi_query_norm_type='layernorm', num_query=2) for _ in range(num_blocks)]
+                )
+            # ... the implementations for the other mlp_type values are unchanged.
+
+        # Output layers (NOTE: any truthy action_norm, including the string "layernorm", selects L2Norm here).
+        self.norm = L2Norm() if action_norm else nn.LayerNorm(hidden_dim)
+        self.dropout = nn.Dropout(drop_ratio) if drop_ratio != 0 else nn.Identity()
+        self.action_projection = nn.Linear(hidden_dim, output_dims)
+
+    def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor:
+        """
+        Args:
+            x: (B, seq_len, input_dim) action-related hidden states
+            condition: (B, text_seq_len, text_dim) hidden states of the text instruction (or other conditioning)
+
+        Returns:
+            actions: (B, seq_len, output_dims) predicted actions
+        """
+        # Input projection
+        x = self.hidden_projection(x)
+
+        # Main processing
+        if condition is not None:
+            # Apply the adaLN-Zero-modulated blocks.
+            for block in self.mlps:
+                x = block(x, condition)
+
+        # Output
+        x = self.norm(x)
+        x = self.action_projection(self.dropout(x))
+
+        return x
+
+
+class AdaLNZeroTSActionHead(nn.Module):
+    def __init__(
+        self,
+        input_dim=4096,
+        hidden_dim=4096,
+        text_dim=4096,
+        action_dim=7,
+        chunk_size=8,
+        decoder_num_blocks=2,
+        proj_type='gelu_linear',
+        mlp_type='adaln_zero',
+        ffn_type='gelu',
+        drop_ratio=0.1,
+        without_action_projector=False,
+        expansion_ratio=2.0,
+        use_visualcondition=False,  # new parameter
+        action_norm=False,
+        **kwargs
+    ):
+        super().__init__()
+        self.action_dim = action_dim
+        self.chunk_size = chunk_size
+        self.text_dim = text_dim
+        self.use_visualcondition = use_visualcondition
+
+        self.head = AdaLNZeroRobotDecoder(
+            num_blocks=decoder_num_blocks,
+            input_dim=input_dim,
+            hidden_dim=hidden_dim,
+            text_dim=text_dim,
+            output_dims=action_dim * chunk_size,
+            mlp_type=mlp_type,
+            ffn_type=ffn_type,
+            proj_type=proj_type,
+            drop_ratio=drop_ratio,
+            without_action_projector=without_action_projector,
+            expansion_ratio=expansion_ratio,
+            action_norm=action_norm
+        )
+
+    def predict_action(
+        self,
+        actions_hidden_states,
+        text_hidden_states=None,
+        visual_condition=None,  # new parameter
+        num_action_chunk=8
+    ):
+        """
+        Args:
+            actions_hidden_states: (B, 1, input_dim)
+            text_hidden_states: (B, text_seq_len, text_dim)
+            visual_condition: (B, vis_seq_len, vis_dim) visual latents
+            num_action_chunk: int
+        """
+        # Choose the condition according to use_visualcondition.
+        if self.use_visualcondition:
+            condition = visual_condition
+        else:
+            condition = text_hidden_states
+
+        actions = self.head(actions_hidden_states, condition=condition)
+        actions = actions.reshape(actions.size(0), self.chunk_size, -1)
+        return actions
+
+
+class DualArmTSActionHead(nn.Module):
+    """
+    Action head designed specifically for dual-arm robots; the two tokens are processed separately.
+    Input: (B, 2, hidden_dim) - token 0 controls the left arm, token 1 controls the right arm.
+    """
+    def __init__(
+        self,
+        input_dim=4096,
+        hidden_dim=4096,
+        action_dim=7,
+        chunk_size=8,
+        decoder_num_blocks=2,
+        proj_type='gelu_linear',
+        mlp_type='ffn',
+        ffn_type='gelu',
+        drop_ratio=0.1,
+        without_action_projector=False,
+        action_norm="layernorm",
+        # MoE-related parameters
+        num_experts=6,
+        top_k=2,
+        expert_capacity_factor=1.0,
+        expansion_ratio=2.0,
+        num_shared_experts=1,
+        use_contrastive_loss=False,
+        **kwargs
+    ):
+        super().__init__()
+        self.chunk_size = chunk_size
+        self.action_dim = 7  # per-arm action dimensionality is fixed to 7; the action_dim argument is not used
+        self.use_contrastive_loss = use_contrastive_loss
+
+        # Independent decoder for the left arm (processes token 0).
+        self.left_arm_decoder = RobotDecoder(
+            num_blocks=decoder_num_blocks,
+            input_dim=input_dim,
+            hidden_dim=hidden_dim,
+            output_dims=NUM_ACTIONS_CHUNK * self.action_dim,
+            mlp_type=mlp_type,
+            proj_type=proj_type,
+            ffn_type=ffn_type,
+            drop_ratio=drop_ratio,
+            without_action_projector=without_action_projector,
+            action_norm=action_norm,
+            num_experts=num_experts,
+            top_k=top_k,
+            expert_capacity_factor=expert_capacity_factor,
+            expansion_ratio=expansion_ratio,
+            num_shared_experts=num_shared_experts,
+            use_contrastive_loss=False,  # the individual decoders do not need the contrastive loss
+            multi_query_norm_type='layernorm',  # each arm uses a standard LayerNorm internally
+            num_query=1  # each decoder handles a single arm
+        )
+
+        # Independent decoder for the right arm (processes token 1).
+        self.right_arm_decoder = RobotDecoder(
+            num_blocks=decoder_num_blocks,
+            input_dim=input_dim,
+            hidden_dim=hidden_dim,
+            output_dims=NUM_ACTIONS_CHUNK * self.action_dim,
+            mlp_type=mlp_type,
+            proj_type=proj_type,
+            ffn_type=ffn_type,
+            drop_ratio=drop_ratio,
+            without_action_projector=without_action_projector,
+            action_norm=action_norm,
+            num_experts=num_experts,
+            top_k=top_k,
+            expert_capacity_factor=expert_capacity_factor,
+            expansion_ratio=expansion_ratio,
+            num_shared_experts=num_shared_experts,
+            use_contrastive_loss=False,  # the individual decoders do not need the contrastive loss
+            multi_query_norm_type='layernorm',  # each arm uses a standard LayerNorm internally
+            num_query=1  # each decoder handles a single arm
+        )
+
+    def predict_action(self, actions_hidden_states, num_action_chunk=8):
+        """
+        Args:
+            actions_hidden_states: (batch_size, 2, hidden_dim)
+                - [:, 0, :] is the left-arm token
+                - [:, 1, :] is the right-arm token
+        Returns:
+            actions: (batch_size, NUM_ACTIONS_CHUNK, action_dim * 2)  # [batch, time, left_arm_actions + right_arm_actions]
+        """
+        batch_size = actions_hidden_states.size(0)
+        assert actions_hidden_states.size(1) == 2, f"Expected 2 tokens for dual arms, got {actions_hidden_states.size(1)}"
+
+        # Split the left-arm and right-arm tokens.
+        left_arm_token = actions_hidden_states[:, 0:1, :]  # (B, 1, hidden_dim) - left-arm token
+        right_arm_token = actions_hidden_states[:, 1:2, :]  # (B, 1, hidden_dim) - right-arm token
+
+        # Pass each token through its own decoder.
+        left_actions = self.left_arm_decoder(left_arm_token)  # (B, 1, action_dim * NUM_ACTIONS_CHUNK)
+        right_actions = self.right_arm_decoder(right_arm_token)  # (B, 1, action_dim * NUM_ACTIONS_CHUNK)
+
+        # Reshape into time-series format.
+        left_actions = left_actions.reshape(batch_size, NUM_ACTIONS_CHUNK, self.action_dim)  # (B, T, action_dim)
+        right_actions = right_actions.reshape(batch_size, NUM_ACTIONS_CHUNK, self.action_dim)  # (B, T, action_dim)
+
+        # Concatenate the left- and right-arm actions on the last dimension: (batch_size, NUM_ACTIONS_CHUNK, action_dim * 2)
+        dual_arm_actions = torch.cat([left_actions, right_actions], dim=-1)
+
+        if self.use_contrastive_loss:
+            # If contrastive learning is needed, also return the concatenated token features.
+            arm_features = torch.cat([left_arm_token, right_arm_token], dim=1)  # (B, 2, hidden_dim)
+            return dual_arm_actions, arm_features
+        else:
+            return dual_arm_actions
+
+
+"""
+Mitigating action jitter on dual-arm robots:
+
+Problem:
+When the actions have shape (b, l, n*d) with l=2 for the two arms, using LayerNorm in the final
+linear projection stage can cause severe jitter in the dual-arm actions.
+
+Why:
+1. LayerNorm normalizes over the feature dimension and destroys the relative relationship between the two arms' actions.
+2. A dual-arm robot must stay coordinated across arms; LayerNorm can break this coordination.
+3. The statistics computed by LayerNorm mix information from both arms, reducing their independence.
+
+Recommended solutions (in order of priority):
+
+1. per_limb_layernorm / per_limb_rmsnorm (strongly recommended)
+   - Create an independent normalization layer for each limb.
+   - Completely avoids cross-limb normalization interference.
+   - Preserves each limb's independence and coordination.
+
+2. rmsnorm (recommended)
+   - Preserves relative feature relationships better than LayerNorm.
+   - Less disruptive to dual-arm coordination.
+
+3. layerscale (recommended)
+   - Lightweight learnable scaling.
+   - Maintains coordination between the arms and avoids over-normalization.
+
+4. l2 (recommended)
+   - Preserves relative magnitudes.
+   - Does not break the relative relationship between the arms.
+
+5. none/identity (recommended)
+   - No normalization at all, so coordination cannot be disturbed.
+   - Suitable once the model is already stable.
+
+Not recommended:
+- layernorm: may cause dual-arm jitter.
+- Complex adaptive normalization: may introduce instability.
+
+Usage advice:
+For dual-arm robots, try per_limb_layernorm or per_limb_rmsnorm first;
+if problems persist, try rmsnorm or simply use none.
+"""
+
+
+class TemporalConvActionHead(nn.Module):
+    """Action decoder based on a 1D temporal convolutional network (TCN).
+
+    NOTE: this definition shadows the TemporalConvActionHead defined earlier in this file.
+
+    This decoder is designed to consume a single aggregated feature vector from the VLM. It first
+    projects and reshapes that vector into a feature sequence, then uses stacked dilated convolutions
+    to explicitly model the temporal dependencies between actions, aiming to produce smoother,
+    more coherent action trajectories.
+    """
+
+    def __init__(
+        self,
+        input_dim: int = 4096,
+        hidden_dim: int = 512,
+        action_dim: int = ACTION_DIM,
+        chunk_size: int = NUM_ACTIONS_CHUNK,
+        num_layers: int = 4,
+        kernel_size: int = 3,
+        dropout: float = 0.1,
+    ) -> None:
+        """
+        Args:
+            input_dim (int): Feature dimensionality of the VLM backbone output.
+            hidden_dim (int): Hidden dimensionality inside the TCN.
+            action_dim (int): Robot action dimensionality per timestep.
+            chunk_size (int): Length of the predicted action sequence.
+            num_layers (int): Number of TCN layers.
+            kernel_size (int): Convolution kernel size.
+            dropout (float): Dropout probability inside the TCN.
+        """
+        super().__init__()
+        self.action_dim = action_dim
+        self.chunk_size = chunk_size
+
+        # Step 1: project the single input embedding into a sequence.
+        self.input_projection = nn.Sequential(
+            nn.Linear(input_dim, hidden_dim * chunk_size),
+            nn.GELU()
+        )
+
+        # Step 2: build the TCN.
+        layers = []
+        in_channels = hidden_dim
+        dilation = 1
+        for _ in range(num_layers):
+            layers.append(
+                nn.Sequential(
+                    nn.Conv1d(
+                        in_channels,
+                        hidden_dim,
+                        kernel_size,
+                        padding=(kernel_size - 1) * dilation // 2,  # 'same' padding
+                        dilation=dilation,
+                    ),
+                    nn.BatchNorm1d(hidden_dim),
+                    nn.GELU(),
+                    nn.Dropout(dropout),
+                )
+            )
+            # The TCN keeps the channel count fixed here, but this can be adjusted if desired.
+            # in_channels = hidden_dim
+            dilation *= 2
+        self.tcn = nn.Sequential(*layers)
+
+        # Step 3: final output layer.
+        self.output_projection = nn.Linear(hidden_dim, action_dim)
+
+    def predict_action(self, actions_hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
+        """Predict the whole action sequence from a single feature vector.
+
+        Args:
+            actions_hidden_states (torch.Tensor): Aggregated feature output by the VLM,
+                of shape (batch_size, 1, input_dim).
+
+        Returns:
+            torch.Tensor: Predicted action sequence of shape (batch_size, chunk_size, action_dim).
+        """
+        B, N, _ = actions_hidden_states.shape
+        assert N == 1, f"TemporalConvActionHead expects a single feature vector, but received {N}."
+
+        # (B, 1, D) -> (B, D)
+        x = actions_hidden_states.squeeze(1)
+
+        # Project to sequence features: (B, hidden_dim * chunk_size)
+        x = self.input_projection(x)
+
+        # Reshape for Conv1d, (B, C, L): (B, hidden_dim, chunk_size)
+        x = x.view(B, self.chunk_size, -1).permute(0, 2, 1)
+
+        # Temporal modeling with the TCN.
+        x = self.tcn(x)  # (B, hidden_dim, chunk_size)
+
+        # Rearrange for the linear layer, (B, L, C): (B, chunk_size, hidden_dim)
+        x = x.permute(0, 2, 1)
+
+        # Map to the final action space.
+        actions = self.output_projection(x)  # (B, chunk_size, action_dim)
+
+        return actions
+
+
+class KeyframeActionHead(nn.Module):
+    """
+    Predicts actions for a sparse sequence of keyframe embeddings.
+    It processes each keyframe embedding independently to predict the corresponding action.
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        action_dim: int,
+        decoder_num_blocks: int = 2,
+        mlp_type: str = "ffn",
+        proj_type: str = 'relu_linear',
+        drop_ratio: float = 0.0,
+        num_query: int = NUM_ACTIONS_CHUNK,
+        **kwargs,
+    ):
+        super().__init__()
+        self.action_dim = action_dim
+
+        # This decoder will process each keyframe embedding in the sequence.
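+        # output_dims=action_dim means each keyframe token is decoded into a single action;
+        # a dense trajectory can then be obtained by upsampling the predicted keyframes with
+        # interpolate_keyframes() defined below.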
+ self.head = RobotDecoder( + num_blocks=decoder_num_blocks, + input_dim=input_dim, + hidden_dim=hidden_dim, + output_dims=action_dim, # Predict a single action per input embedding + mlp_type=mlp_type, + proj_type=proj_type, + drop_ratio=drop_ratio, + num_query=num_query, # Process each keyframe independently + **kwargs, + ) + + def predict_action(self, keyframe_hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Predicts actions for the given keyframe embeddings. + Args: + keyframe_hidden_states (torch.Tensor): A sparse sequence of hidden states, + of shape [B, num_keyframes, D_in]. + Returns: + torch.Tensor: Predicted actions for each keyframe, of shape [B, num_keyframes, D_action]. + """ + # The RobotDecoder will apply the transformation to each element in the sequence. + predicted_keyframes = self.head(keyframe_hidden_states) + return predicted_keyframes + + +def interpolate_keyframes(keyframes: torch.Tensor, final_len: int, smoothing_sigma: float) -> torch.Tensor: + """ + Interpolates keyframes to a final length and applies Gaussian smoothing. + Args: + keyframes (torch.Tensor): Tensor of shape [B, num_keyframes, D_action]. + final_len (int): The final length of the sequence (T). + smoothing_sigma (float): Sigma for Gaussian smoothing. A value of 0 disables smoothing. + Returns: + torch.Tensor: Interpolated and smoothed actions of shape [B, final_len, D_action]. + """ + # Transpose to (B, D_action, num_keyframes) for interpolation + keyframes_t = keyframes.transpose(1, 2) + + # Use torch.nn.functional.interpolate for batched linear interpolation + interpolated_actions_t = torch.nn.functional.interpolate( + keyframes_t, size=final_len, mode='linear', align_corners=True + ) + + # Apply Gaussian smoothing if sigma > 0 + if smoothing_sigma > 0: + # torchvision.transforms.functional.gaussian_blur expects a 4D tensor (B, C, H, W) + # We treat sequence length T as width W, and add a dummy height H=1. + interpolated_actions_4d = interpolated_actions_t.unsqueeze(2) + # Kernel size is derived from sigma, must be odd. + kernel_size = int(4 * smoothing_sigma + 0.5) * 2 + 1 + + smoothed_actions_4d = gaussian_blur( + interpolated_actions_4d, kernel_size=(1, kernel_size), sigma=(smoothing_sigma, smoothing_sigma) + ) + interpolated_actions_t = smoothed_actions_4d.squeeze(2) + + # Transpose back to (B, final_len, D_action) + interpolated_actions = interpolated_actions_t.transpose(1, 2) + return interpolated_actions \ No newline at end of file
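+
+
+# Illustrative sketch (not necessarily the implementation used elsewhere in this file) of the
+# "per_limb_layernorm" option recommended in the dual-arm notes above: each limb gets its own
+# LayerNorm so that normalization statistics are never shared across arms. The class and argument
+# names below are assumptions for illustration only.
+class PerLimbLayerNormSketch(nn.Module):
+    def __init__(self, num_limbs: int, per_limb_dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.num_limbs = num_limbs
+        self.per_limb_dim = per_limb_dim
+        # One independent LayerNorm per limb.
+        self.norms = nn.ModuleList([nn.LayerNorm(per_limb_dim, eps=eps) for _ in range(num_limbs)])
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (B, T, num_limbs * per_limb_dim), e.g. concatenated left-/right-arm features.
+        chunks = x.chunk(self.num_limbs, dim=-1)  # num_limbs tensors of shape (B, T, per_limb_dim)
+        normed = [norm(chunk) for norm, chunk in zip(self.norms, chunks)]
+        return torch.cat(normed, dim=-1)  # (B, T, num_limbs * per_limb_dim)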