# cross13tasks/code/model/modules/action_model/GR00T_ActionHeader_moh.py
"""
Mixture of Horizons (MoH) version of the GR00T ActionHeader.
This is a reference implementation showing how to integrate the MoH strategy into GR00T.
Key changes:
1. Support multiple horizons (e.g., [5, 10, 15, 20, 50])
2. Parallel processing via batching (batch_size * num_horizons)
3. Gating network for ensemble
4. Multi-component loss (individual + auxiliary + load balancing)
"""
import torch
import torch.nn.functional as F
from torch import nn
from typing import Optional, List
from dataclasses import dataclass, field
from starVLA.model.modules.action_model.flow_matching_head.action_encoder import (
SinusoidalPositionalEncoding,
swish,
)
from starVLA.model.modules.action_model.flow_matching_head.cross_attention_dit import DiT
from starVLA.model.modules.action_model.GR00T_ActionHeader import (
FlowmatchingActionHeadConfig,
ActionEncoder,
MLP,
)
class FlowmatchingActionHeadMoH(nn.Module):
"""
    GR00T ActionHeader with Mixture of Horizons support.
Key differences from original:
- Supports multiple horizons (e.g., [5, 10, 15, 20, 50])
- Processes all horizons in parallel via batching
- Uses gating network to ensemble predictions
- Multi-component loss function
"""
def __init__(
self,
full_config,
        horizons: List[int] = [2, 5, 8],  # Different horizon lengths
use_gate_noise: bool = True, # Add learnable noise to gate logits
):
super().__init__()
config = full_config.framework.action_model
self.horizons = sorted(horizons) # Ensure sorted
self.max_horizon = self.horizons[-1]
self.num_horizons = len(self.horizons)
self.use_gate_noise = use_gate_noise
self.hidden_size = config.hidden_size
self.full_config = full_config
action_model_type = config.action_model_type
action_model_cfg = {
"DiT-B": {"input_embedding_dim": 768, "attention_head_dim": 64, "num_attention_heads": 12},
"DiT-L": {"input_embedding_dim": 1536, "attention_head_dim": 48, "num_attention_heads": 32},
}[action_model_type]
self.input_embedding_dim = action_model_cfg["input_embedding_dim"]
diffusion_model_cfg = config.diffusion_model_cfg
diffusion_model_cfg = {**action_model_cfg, **diffusion_model_cfg}
self.model = DiT(**diffusion_model_cfg)
self.action_dim = config.action_dim
self.action_horizon = config.future_action_window_size + 1
self.num_inference_timesteps = config.num_inference_timesteps
self.state_encoder = MLP(
input_dim=config.state_dim,
hidden_dim=self.hidden_size,
output_dim=self.input_embedding_dim,
) if config.state_dim else None
self.action_encoder = ActionEncoder(
action_dim=config.action_dim,
hidden_size=self.input_embedding_dim,
)
self.action_decoder = MLP(
input_dim=self.model.config.output_dim,
hidden_dim=self.hidden_size,
output_dim=self.action_dim,
)
self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)
if config.add_pos_embed:
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
self.beta_dist = torch.distributions.Beta(config.noise_beta_alpha, config.noise_beta_beta)
self.num_timestep_buckets = config.num_timestep_buckets
self.config = config
# MoH-specific components
# Gating network: predicts weights for each horizon at each timestep
# Input: model output features, Output: gate logits for each horizon
self.gate_out_proj = nn.Linear(self.model.config.output_dim, 1)
if self.use_gate_noise:
self.gate_noise_layer = nn.Linear(self.model.config.output_dim, 1)
self.softplus = nn.Softplus()
print(f"MoH ActionHeader initialized with horizons: {self.horizons}")
    def sample_time(self, batch_size, device, dtype):
        """Sample flow-matching times t in [0, 1] from a reflected, rescaled Beta.

        A draw x ~ Beta(alpha, beta) is clamped to s = noise_s and mapped to
        t = (s - x) / s, so t = 0 corresponds to pure noise and t = 1 to clean
        actions in this file's convention.
        """
        sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype).clamp(max=self.config.noise_s)
        return (self.config.noise_s - sample) / self.config.noise_s
def cv_squared(self, x):
"""Coefficient of variation squared for load balancing."""
eps = 1e-10
if x.shape[0] == 1:
return torch.tensor(0.0, device=x.device, dtype=x.dtype)
return x.float().var() / (x.float().mean() ** 2 + eps)
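    # Example for cv_squared above: per-horizon mean gate probabilities
    # p = [0.5, 0.3, 0.2] give var(p) / mean(p)^2 ≈ 0.0233 / 0.1111 ≈ 0.21
    # (unbiased variance); a perfectly uniform p gives 0, so minimizing this
    # term pushes the router toward balanced horizon usage.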
def forward(
self,
vl_embs: torch.Tensor,
actions: torch.Tensor,
        state: Optional[torch.Tensor] = None,
        encoder_attention_mask=None,
        loss_config: Optional[dict] = None
):
"""
Forward pass with MoH strategy.
Args:
vl_embs: (B, seq_length, feature_dim) - Vision-language embeddings
actions: (B, max_horizon, D_action) - Ground truth actions (padded to max_horizon)
state: (B, state_dim) - Optional state features
encoder_attention_mask: Attention mask for encoder
loss_config: Dict with 'aux_weight' and 'balance_weight'
Returns:
total_loss: Combined loss from all components
"""
device = vl_embs.device
batch_size = actions.shape[0]
num_horizons = len(self.horizons)
max_horizon = self.max_horizon
# Sample noise and time
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
time_scalar = self.sample_time(batch_size, device, actions.dtype)
# Expand time for each horizon: (num_h, batch_size)
time = time_scalar.unsqueeze(0).expand(num_horizons, -1)
# x_t: (num_h, batch_size, max_horizon, action_dim)
# Flow matching: x_t = (1-t) * noise + t * actions, where t=0 is noise and t=1 is actions
# Expand noise and actions to (num_h, batch_size, max_horizon, action_dim)
noise_expanded = noise.unsqueeze(0).expand(num_horizons, -1, -1, -1)
actions_expanded = actions.unsqueeze(0).expand(num_horizons, -1, -1, -1)
t_expanded = time[:, :, None, None] # (num_h, batch_size, 1, 1)
x_t = (1 - t_expanded) * noise_expanded + t_expanded * actions_expanded
# u_t (target velocity): (batch_size, max_horizon, action_dim)
u_t = actions - noise
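        # e.g. at t = 0.25, x_t = 0.75 * noise + 0.25 * actions; the target
        # velocity u_t = actions - noise is constant along this straight-line
        # path, which is exactly what the flow-matching loss regresses onto.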
# ============================================================
# STAGE 1: Prepare inputs for parallel processing
# ============================================================
# Repeat vl_embs and state for each horizon
batched_vl_embs = vl_embs.repeat_interleave(num_horizons, dim=0) # (B*H, seq_len, dim)
batched_state = state.repeat_interleave(num_horizons, dim=0) if state is not None else None
if encoder_attention_mask is not None:
batched_encoder_attention_mask = encoder_attention_mask.repeat_interleave(num_horizons, dim=0)
else:
batched_encoder_attention_mask = None
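        # Batch-major interleave convention used from here on: with horizons
        # [2, 5, 8] (H = 3), batched rows 0..2 hold sample 0 at horizons
        # 2, 5, 8 and rows 3..5 hold sample 1; repeat_interleave above
        # produces exactly this ordering, and every reshape below assumes it.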
# Reshape x_t and time for batched processing
# x_t: (num_h, batch_size, max_horizon, dim) -> (batch_size * num_h, max_horizon, dim)
batched_x_t = x_t.permute(1, 0, 2, 3).reshape(batch_size * num_horizons, max_horizon, -1)
# time: (num_h, batch_size) -> (batch_size * num_h)
batched_time = time.permute(1, 0).reshape(batch_size * num_horizons)
# Create padding masks for each horizon
# action_pad_mask: (num_h, max_horizon) - True where valid, False where padding
        action_pad_mask = (
            torch.arange(max_horizon, device=device)[None, :]
            < torch.tensor(self.horizons, device=device)[:, None]
        )
# Expand to batch: (num_h, batch_size, max_horizon)
action_pad_mask = action_pad_mask.unsqueeze(1).expand(-1, batch_size, -1)
# Reshape: (batch_size * num_h, max_horizon)
batched_action_pad_mask = action_pad_mask.permute(1, 0, 2).reshape(batch_size * num_horizons, max_horizon)
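        # e.g. horizons [2, 5, 8] -> per-horizon validity rows (True = valid):
        #   [T T F F F F F F]
        #   [T T T T T F F F]
        #   [T T T T T T T T]
        # Note: the per-horizon loss below slices each head to its own horizon
        # directly, so this mask mainly documents the valid/padded layout.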
# ============================================================
# STAGE 2: Forward pass through model (parallel for all horizons)
# ============================================================
# Convert time to discrete timesteps
t_discretized = (batched_time * self.num_timestep_buckets).long()
# Encode actions
action_features = self.action_encoder(batched_x_t, t_discretized)
# Add position embedding if needed
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
action_features = action_features + pos_embs
# Prepare state and action embeddings
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(
batch_size * num_horizons, -1, -1
)
if batched_state is not None:
state_features = self.state_encoder(batched_state)
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
else:
sa_embs = torch.cat((future_tokens, action_features), dim=1)
# Forward through DiT model
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=batched_vl_embs,
encoder_attention_mask=batched_encoder_attention_mask,
timestep=t_discretized,
return_all_hidden_states=False,
)
# Decode actions
pred = self.action_decoder(model_output)
# Extract action predictions (last max_horizon tokens)
# pred: (B*H, seq_len, action_dim) -> (B*H, max_horizon, action_dim)
state_offset = 1 if state is not None else 0
future_tokens_len = self.future_tokens.num_embeddings
action_start_idx = state_offset + future_tokens_len
pred_actions_padded = pred[:, action_start_idx:action_start_idx + max_horizon, :]
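        # e.g. with one state token and, say, 32 future tokens, the action
        # slice is pred[:, 33 : 33 + max_horizon, :] (33 = 1 + 32).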
# Reshape to separate predictions for each horizon
# (B*H, max_horizon, dim) -> (B, H, max_horizon, dim) -> (H, B, max_horizon, dim)
all_v_t_preds = pred_actions_padded.view(
batch_size, num_horizons, max_horizon, -1
).permute(1, 0, 2, 3)
# ============================================================
# STAGE 3: Compute losses
# ============================================================
# 1. Individual loss: Each horizon's prediction vs target
all_head_losses = []
for i, h in enumerate(self.horizons):
v_t_head = all_v_t_preds[i, :, :h, :] # (B, h, dim)
target_v_t = u_t[:, :h, :] # (B, h, dim)
head_loss = F.mse_loss(v_t_head, target_v_t)
all_head_losses.append(head_loss)
individual_loss = torch.sum(torch.stack(all_head_losses))
# 2. Gating network: Generate weights for ensemble
gate_logits = self.gate_out_proj(model_output.to(torch.float32))
gate_logits = gate_logits[:, action_start_idx:action_start_idx + max_horizon, :] # (B*H, max_horizon, 1)
if self.use_gate_noise:
# Add learnable noise to gate logits
noise_epsilon = 1e-2
raw_noise_stddev = self.gate_noise_layer(model_output.to(torch.float32))
raw_noise_stddev = raw_noise_stddev[:, action_start_idx:action_start_idx + max_horizon, :]
noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
gate_logits = gate_logits + (torch.randn_like(gate_logits) * noise_stddev)
# Reshape gate logits: (B*H, max_horizon, 1) -> (B, H, max_horizon) -> (B, max_horizon, H)
gate_logits = gate_logits.reshape(batch_size, num_horizons, max_horizon).permute(0, 2, 1)
# Apply mask: invalid horizons (where step >= horizon) get -inf
valid_heads_mask = torch.tensor(
[[step < h for h in self.horizons] for step in range(max_horizon)],
device=device, dtype=torch.bool
).unsqueeze(0) # (1, max_horizon, H)
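        # e.g. with horizons [2, 5, 8]: at step 0 all three heads are valid;
        # at step 3 only the 5- and 8-step heads remain; from step 5 onward
        # only the 8-step head is valid, so it receives softmax weight 1.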
masked_gate_logits = torch.where(
valid_heads_mask,
gate_logits,
torch.finfo(gate_logits.dtype).min
)
gate_weights = F.softmax(masked_gate_logits, dim=-1) # (B, max_horizon, H)
# 3. Ensemble predictions using gate weights
# all_v_t_preds: (H, B, max_horizon, dim) -> (B, H, max_horizon, dim)
all_v_t_preds_padded = all_v_t_preds.permute(1, 0, 2, 3)
# gate_weights: (B, max_horizon, H) -> (B, H, max_horizon, 1)
# combined: (B, max_horizon, dim)
v_t_combined = (gate_weights.permute(0, 2, 1).unsqueeze(-1) * all_v_t_preds_padded).sum(dim=1)
# Auxiliary loss: Ensemble prediction vs target
aux_loss_weight = loss_config.get("aux_weight", 1.0) if loss_config else 1.0
auxiliary_loss = F.mse_loss(v_t_combined, u_t)
# 4. Load balancing loss: Encourage balanced usage of horizons
loss_components = []
boundaries = sorted(list(set([0] + self.horizons)))
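        # e.g. horizons [2, 5, 8] -> boundaries [0, 2, 5, 8] -> segments
        # [0,2) with all three heads active, [2,5) with the 5- and 8-step
        # heads, and [5,8) with only the 8-step head (no balancing term
        # there, since a single active expert cannot be imbalanced).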
for i in range(len(boundaries) - 1):
start_step, end_step = boundaries[i], boundaries[i + 1]
active_expert_indices = [idx for idx, h in enumerate(self.horizons) if h > start_step]
if len(active_expert_indices) > 1:
segment_gate_weights = gate_weights[:, start_step:end_step, :]
active_expert_weights = segment_gate_weights[:, :, active_expert_indices]
avg_expert_prob_in_segment = active_expert_weights.mean(dim=(0, 1))
segment_loss = self.cv_squared(avg_expert_prob_in_segment)
loss_components.append(segment_loss)
load_balancing_loss = torch.mean(torch.stack(loss_components)) if loss_components else torch.tensor(0.0, device=device)
balance_loss_weight = loss_config.get("balance_weight", 0.001) if loss_config else 0.001
# Total loss
total_loss = individual_loss + aux_loss_weight * auxiliary_loss + balance_loss_weight * load_balancing_loss
return total_loss
@torch.no_grad()
def predict_action(
self,
vl_embs: torch.Tensor,
        state: Optional[torch.Tensor] = None,
ret_weights: bool = False
) -> dict:
"""
Inference with MoH ensemble.
Args:
vl_embs: (B, seq_length, feature_dim)
state: (B, state_dim) - Optional
ret_weights: Whether to return gate weights
Returns:
dict with 'actions' and optionally 'gate_weights'
"""
batch_size = vl_embs.shape[0]
device = vl_embs.device
num_horizons = len(self.horizons)
max_horizon = self.max_horizon
# Initialize actions as noise
actions = torch.randn(
size=(batch_size, max_horizon, self.action_dim),
dtype=vl_embs.dtype,
device=device,
)
num_steps = self.num_inference_timesteps
dt = 1.0 / num_steps
gate_weights_to_log = []
# Prepare batched inputs (same for all denoising steps)
batched_vl_embs = vl_embs.repeat_interleave(num_horizons, dim=0)
batched_state = state.repeat_interleave(num_horizons, dim=0) if state is not None else None
# Denoising loop
for t in range(num_steps):
t_cont = t / float(num_steps)
t_discretized = int(t_cont * self.num_timestep_buckets)
timesteps_tensor = torch.full(
size=(batch_size * num_horizons,),
fill_value=t_discretized,
device=device
)
            # Pad the current actions to max_horizon for each head and stack
            # batch-major, matching the repeat_interleave layout of
            # batched_vl_embs / batched_state (row b*H + i holds sample b at
            # horizon self.horizons[i]).
            padded_x_t_list = [
                F.pad(actions[:, :h, :], (0, 0, 0, max_horizon - h))
                for h in self.horizons
            ]
            # (B, H, max_horizon, dim) -> (B*H, max_horizon, dim)
            batched_x_t = torch.stack(padded_x_t_list, dim=1).reshape(
                batch_size * num_horizons, max_horizon, -1
            )
# Encode actions
action_features = self.action_encoder(batched_x_t, timesteps_tensor)
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
action_features = action_features + pos_embs
# Prepare embeddings
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(
batch_size * num_horizons, -1, -1
)
if batched_state is not None:
state_features = self.state_encoder(batched_state)
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
else:
sa_embs = torch.cat((future_tokens, action_features), dim=1)
# Forward through model
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=batched_vl_embs,
timestep=timesteps_tensor,
)
# Decode actions
pred = self.action_decoder(model_output)
# Extract action predictions
state_offset = 1 if state is not None else 0
future_tokens_len = self.future_tokens.num_embeddings
action_start_idx = state_offset + future_tokens_len
pred_actions_padded = pred[:, action_start_idx:action_start_idx + max_horizon, :]
            # Reshape: (B*H, max_horizon, dim) -> (B, H, max_horizon, dim),
            # following the batch-major interleave used throughout
            all_v_t_preds_padded = pred_actions_padded.view(
                batch_size, num_horizons, max_horizon, -1
            )
# Gating network
gate_logits = self.gate_out_proj(model_output.to(torch.float32))
gate_logits = gate_logits[:, action_start_idx:action_start_idx + max_horizon, :]
gate_logits = gate_logits.reshape(batch_size, num_horizons, max_horizon).permute(0, 2, 1)
valid_heads_mask = torch.tensor(
[[step < h for h in self.horizons] for step in range(max_horizon)],
device=device, dtype=torch.bool
).unsqueeze(0)
masked_gate_logits = torch.where(
valid_heads_mask,
gate_logits,
torch.finfo(gate_logits.dtype).min
)
gate_weights = F.softmax(masked_gate_logits, dim=-1)
if ret_weights:
gate_weights_to_log.append(torch.round(gate_weights, decimals=3))
# Ensemble predictions
v_t = (gate_weights.permute(0, 2, 1).unsqueeze(-1) * all_v_t_preds_padded).sum(dim=1)
# Euler update
actions = actions + dt * v_t
return_dict = {"actions": actions}
if ret_weights and len(gate_weights_to_log) > 0:
return_dict["gate_weights"] = torch.stack(gate_weights_to_log, dim=1).detach().cpu()
        return return_dict
def get_action_model(config=None, horizons: List[int] = [2, 5, 8]):
"""
Factory: build FlowmatchingActionHeadMoH from global framework config.
Args:
config: Global config (expects config.framework.action_model namespace).
horizons: List of horizon lengths to use for MoH
Returns:
FlowmatchingActionHeadMoH: Initialized MoH ActionHeader.
"""
return FlowmatchingActionHeadMoH(
full_config=config,
horizons=horizons,
)
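

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): builds the MoH head from a minimal config
# namespace and runs one training and one inference step on random tensors.
# The field names below mirror the attributes read in __init__; the concrete
# values are placeholder assumptions, and running this requires the starVLA
# package (for DiT and the base GR00T ActionHeader components).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    action_model = SimpleNamespace(
        hidden_size=1024,
        action_model_type="DiT-B",
        diffusion_model_cfg={},           # extra DiT kwargs, if any
        action_dim=7,
        future_action_window_size=7,
        num_inference_timesteps=4,
        state_dim=8,
        num_target_vision_tokens=32,
        add_pos_embed=True,
        max_seq_len=1024,
        noise_beta_alpha=1.5,
        noise_beta_beta=1.0,
        noise_s=0.999,
        num_timestep_buckets=1000,
    )
    cfg = SimpleNamespace(framework=SimpleNamespace(action_model=action_model))

    head = get_action_model(cfg, horizons=[2, 5, 8])
    B = 2
    # vl_embs feature dim is assumed to match the DiT input embedding dim
    vl_embs = torch.randn(B, 64, head.input_embedding_dim)
    actions = torch.randn(B, head.max_horizon, head.action_dim)
    state = torch.randn(B, 1, head.config.state_dim)  # assumes one state token

    loss = head(vl_embs, actions, state=state)           # training step
    out = head.predict_action(vl_embs, state=state)      # denoising loop
    print(loss.item(), out["actions"].shape)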