| """ |
| Mixture of Horizons (MoH) version of GROOT ActionHeader. |
| |
| This is a reference implementation showing how to integrate MoH strategy into GROOT. |
| Key changes: |
| 1. Support multiple horizons (e.g., [5, 10, 15, 20, 50]) |
| 2. Parallel processing via batching (batch_size * num_horizons) |
| 3. Gating network for ensemble |
| 4. Multi-component loss (individual + auxiliary + load balancing) |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| from typing import Optional, List |
| from dataclasses import dataclass, field |
|
|
from starVLA.model.modules.action_model.flow_matching_head.action_encoder import (
    SinusoidalPositionalEncoding,
    swish,
)
from starVLA.model.modules.action_model.flow_matching_head.cross_attention_dit import DiT
from starVLA.model.modules.action_model.GR00T_ActionHeader import (
    FlowmatchingActionHeadConfig,
    ActionEncoder,
    MLP,
)


class FlowmatchingActionHeadMoH(nn.Module):
    """
    GR00T ActionHeader with Mixture of Horizons support.

    Key differences from the original:
    - Supports multiple horizons (e.g., [5, 10, 15, 20, 50])
    - Processes all horizons in parallel via batching
    - Uses a gating network to ensemble the per-horizon predictions
    - Multi-component loss function

    Batching convention: per-horizon copies are laid out batch-major, i.e. row k
    of a (batch_size * num_horizons, ...) tensor holds sample k // num_horizons
    under horizon self.horizons[k % num_horizons], matching
    repeat_interleave(num_horizons, dim=0).
    """

    def __init__(
        self,
        full_config,
        horizons: Optional[List[int]] = None,
        use_gate_noise: bool = True,
    ):
        super().__init__()
        config = full_config.framework.action_model
        # Avoid a mutable default argument; fall back to the reference horizons.
        horizons = horizons if horizons is not None else [2, 5, 8]
        self.horizons = sorted(horizons)
        self.max_horizon = self.horizons[-1]
        self.num_horizons = len(self.horizons)
        self.use_gate_noise = use_gate_noise

        self.hidden_size = config.hidden_size
        self.full_config = full_config
        action_model_type = config.action_model_type
        action_model_cfg = {
            "DiT-B": {"input_embedding_dim": 768, "attention_head_dim": 64, "num_attention_heads": 12},
            "DiT-L": {"input_embedding_dim": 1536, "attention_head_dim": 48, "num_attention_heads": 32},
        }[action_model_type]

        self.input_embedding_dim = action_model_cfg["input_embedding_dim"]
        diffusion_model_cfg = {**action_model_cfg, **config.diffusion_model_cfg}
        self.model = DiT(**diffusion_model_cfg)
        self.action_dim = config.action_dim
        self.action_horizon = config.future_action_window_size + 1
        self.num_inference_timesteps = config.num_inference_timesteps

        self.state_encoder = MLP(
            input_dim=config.state_dim,
            hidden_dim=self.hidden_size,
            output_dim=self.input_embedding_dim,
        ) if config.state_dim else None

        self.action_encoder = ActionEncoder(
            action_dim=config.action_dim,
            hidden_size=self.input_embedding_dim,
        )
        self.action_decoder = MLP(
            input_dim=self.model.config.output_dim,
            hidden_dim=self.hidden_size,
            output_dim=self.action_dim,
        )

        # Learnable future (target-vision) tokens inserted before the action tokens.
        self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
        nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)

        if config.add_pos_embed:
            self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)

        # Beta distribution for sampling flow-matching timesteps during training.
        self.beta_dist = torch.distributions.Beta(config.noise_beta_alpha, config.noise_beta_beta)
        self.num_timestep_buckets = config.num_timestep_buckets
        self.config = config

        # Gating network: one scalar logit per action token, used to weight the
        # per-horizon predictions. Optionally add input-dependent noise to the
        # logits (softplus-parameterized std), as in noisy MoE gating.
        self.gate_out_proj = nn.Linear(self.model.config.output_dim, 1)
        if self.use_gate_noise:
            self.gate_noise_layer = nn.Linear(self.model.config.output_dim, 1)
            self.softplus = nn.Softplus()

        print(f"MoH ActionHeader initialized with horizons: {self.horizons}")

    def sample_time(self, batch_size, device, dtype):
        """Sample flow-matching timesteps t in [0, 1] from a shifted Beta distribution."""
        sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype).clamp(max=self.config.noise_s)
        return (self.config.noise_s - sample) / self.config.noise_s

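    # Numeric sanity check for the transform above (assuming, e.g., noise_s = 0.999):
    # a Beta sample clamped at the max, sample = 0.999, maps to t = (0.999 - 0.999) / 0.999 = 0,
    # while sample = 0 maps to t = 1. Large Beta samples therefore land near the noise
    # end (t = 0) of the interpolation used in forward(), small ones near the data end.
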
    def cv_squared(self, x):
        """Squared coefficient of variation, used as the load-balancing penalty."""
        eps = 1e-10
        if x.shape[0] == 1:
            return torch.tensor(0.0, device=x.device, dtype=x.dtype)
        return x.float().var() / (x.float().mean() ** 2 + eps)

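    # Toy example of the penalty (a sketch, not part of the forward pass):
    # perfectly balanced usage gives cv^2 = 0, e.g. cv_squared(torch.tensor([0.25] * 4)) == 0,
    # while a skewed distribution such as [0.97, 0.01, 0.01, 0.01] gives
    # var / mean^2 = 0.2304 / 0.0625 ≈ 3.7 (torch.var defaults to the unbiased estimator).
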
    def forward(
        self,
        vl_embs: torch.Tensor,
        actions: torch.Tensor,
        state: Optional[torch.Tensor] = None,
        encoder_attention_mask=None,
        loss_config: Optional[dict] = None,
    ):
        """
        Forward pass with the MoH strategy.

        Args:
            vl_embs: (B, seq_length, feature_dim) - Vision-language embeddings
            actions: (B, max_horizon, D_action) - Ground-truth actions (padded to max_horizon)
            state: (B, 1, state_dim) - Optional state features (encoded to a single token)
            encoder_attention_mask: Attention mask for the encoder
            loss_config: Dict with 'aux_weight' and 'balance_weight'

        Returns:
            total_loss: Combined loss over all components
        """
        device = vl_embs.device
        batch_size = actions.shape[0]
        num_horizons = len(self.horizons)
        max_horizon = self.max_horizon

        # Sample noise and a flow-matching timestep (shared across horizons so that
        # every horizon copy of a sample sees the same interpolation point).
        noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
        time_scalar = self.sample_time(batch_size, device, actions.dtype)

        # (num_horizons, B)
        time = time_scalar.unsqueeze(0).expand(num_horizons, -1)

        # Linear interpolation between noise (t = 0) and data (t = 1):
        # x_t = (1 - t) * noise + t * actions.
        noise_expanded = noise.unsqueeze(0).expand(num_horizons, -1, -1, -1)
        actions_expanded = actions.unsqueeze(0).expand(num_horizons, -1, -1, -1)
        t_expanded = time[:, :, None, None]
        x_t = (1 - t_expanded) * noise_expanded + t_expanded * actions_expanded

        # Flow-matching velocity target.
        u_t = actions - noise
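        # Why u_t = actions - noise: differentiating the interpolation above with
        # respect to t gives d(x_t)/dt = actions - noise, so the network is trained
        # to regress this constant velocity field at every sampled t.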

        # Batch all horizons together: (B * num_horizons, ...), batch-major, i.e.
        # [b0h0, b0h1, ..., b0hN-1, b1h0, ...], matching repeat_interleave below.
        batched_vl_embs = vl_embs.repeat_interleave(num_horizons, dim=0)
        batched_state = state.repeat_interleave(num_horizons, dim=0) if state is not None else None
        if encoder_attention_mask is not None:
            batched_encoder_attention_mask = encoder_attention_mask.repeat_interleave(num_horizons, dim=0)
        else:
            batched_encoder_attention_mask = None

        # (num_horizons, B, H, D) -> (B, num_horizons, H, D) -> (B * num_horizons, H, D)
        batched_x_t = x_t.permute(1, 0, 2, 3).reshape(batch_size * num_horizons, max_horizon, -1)
        # (num_horizons, B) -> (B, num_horizons) -> (B * num_horizons,)
        batched_time = time.permute(1, 0).reshape(batch_size * num_horizons)

        # Per-horizon validity mask over action steps (kept for reference; the
        # per-head losses below handle padding by slicing to :h instead).
        action_pad_mask = (
            torch.arange(max_horizon, device=device)[None, :]
            < torch.tensor(self.horizons, device=device)[:, None]
        )
        action_pad_mask = action_pad_mask.unsqueeze(1).expand(-1, batch_size, -1)
        batched_action_pad_mask = action_pad_mask.permute(1, 0, 2).reshape(batch_size * num_horizons, max_horizon)
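        # A minimal layout sketch (assuming batch_size = 2, horizons = [2, 5, 8]):
        #   batched_vl_embs rows:  [v0, v0, v0, v1, v1, v1]
        #   batched_x_t rows:      [b0h2, b0h5, b0h8, b1h2, b1h5, b1h8]
        # Row k therefore pairs sample k // 3 with horizon self.horizons[k % 3],
        # so conditioning and noisy actions stay aligned after batching.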

        # Discretize continuous t into bucket indices for the timestep embedding.
        t_discretized = (batched_time * self.num_timestep_buckets).long()

        # Embed the noisy actions together with the timestep.
        action_features = self.action_encoder(batched_x_t, t_discretized)

        if self.config.add_pos_embed:
            pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
            pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
            action_features = action_features + pos_embs

        # Assemble the DiT input sequence: [state token] + [future tokens] + [action tokens].
        future_tokens = self.future_tokens.weight.unsqueeze(0).expand(
            batch_size * num_horizons, -1, -1
        )
        if batched_state is not None:
            state_features = self.state_encoder(batched_state)
            sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
        else:
            sa_embs = torch.cat((future_tokens, action_features), dim=1)

        model_output = self.model(
            hidden_states=sa_embs,
            encoder_hidden_states=batched_vl_embs,
            encoder_attention_mask=batched_encoder_attention_mask,
            timestep=t_discretized,
            return_all_hidden_states=False,
        )

        pred = self.action_decoder(model_output)

        # Recover the action-token slice from the full output sequence.
        state_offset = 1 if state is not None else 0
        future_tokens_len = self.future_tokens.num_embeddings
        action_start_idx = state_offset + future_tokens_len
        pred_actions_padded = pred[:, action_start_idx:action_start_idx + max_horizon, :]
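        # Token layout sketch (assuming state is present, num_target_vision_tokens = K):
        #   index 0                       -> state token
        #   indices 1..K                  -> future (target-vision) tokens
        #   indices K+1..K+max_horizon    -> action tokens (the slice taken above)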

        # (B * num_horizons, H, D) -> (num_horizons, B, H, D)
        all_v_t_preds = pred_actions_padded.view(
            batch_size, num_horizons, max_horizon, -1
        ).permute(1, 0, 2, 3)

        # Individual loss: each head is supervised only on its own first h steps.
        all_head_losses = []
        for i, h in enumerate(self.horizons):
            v_t_head = all_v_t_preds[i, :, :h, :]
            target_v_t = u_t[:, :h, :]
            head_loss = F.mse_loss(v_t_head, target_v_t)
            all_head_losses.append(head_loss)
        individual_loss = torch.sum(torch.stack(all_head_losses))

        # Gating logits: one scalar per action token, computed in float32 for stability.
        gate_logits = self.gate_out_proj(model_output.to(torch.float32))
        gate_logits = gate_logits[:, action_start_idx:action_start_idx + max_horizon, :]

        if self.use_gate_noise:
            # Input-dependent noise on the logits (noisy gating), std via softplus.
            noise_epsilon = 1e-2
            raw_noise_stddev = self.gate_noise_layer(model_output.to(torch.float32))
            raw_noise_stddev = raw_noise_stddev[:, action_start_idx:action_start_idx + max_horizon, :]
            noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
            gate_logits = gate_logits + (torch.randn_like(gate_logits) * noise_stddev)

        # (B * num_horizons, H, 1) -> (B, num_horizons, H) -> (B, H, num_horizons)
        gate_logits = gate_logits.reshape(batch_size, num_horizons, max_horizon).permute(0, 2, 1)

        # At action step `step`, only heads with horizon h > step produce valid predictions.
        valid_heads_mask = torch.tensor(
            [[step < h for h in self.horizons] for step in range(max_horizon)],
            device=device, dtype=torch.bool,
        ).unsqueeze(0)

        # Mask out invalid heads before the softmax so they receive ~zero weight.
        masked_gate_logits = torch.where(
            valid_heads_mask,
            gate_logits,
            torch.finfo(gate_logits.dtype).min,
        )
        gate_weights = F.softmax(masked_gate_logits, dim=-1)
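        # Masked-softmax sketch (assuming horizons = [2, 5, 8]): at step 3, the h = 2
        # head is invalid, so its logit is replaced by finfo(...).min and the softmax
        # effectively renormalizes over the h = 5 and h = 8 heads only; at step 6,
        # all weight goes to the h = 8 head.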

        # Ensemble: weighted sum of the per-horizon predictions.
        # (num_horizons, B, H, D) -> (B, num_horizons, H, D)
        all_v_t_preds_padded = all_v_t_preds.permute(1, 0, 2, 3)

        # gate_weights: (B, H, N) -> (B, N, H, 1); broadcast-multiply, sum over heads -> (B, H, D)
        v_t_combined = (gate_weights.permute(0, 2, 1).unsqueeze(-1) * all_v_t_preds_padded).sum(dim=1)

        # Auxiliary loss: supervise the ensembled prediction over the full max_horizon.
        aux_loss_weight = loss_config.get("aux_weight", 1.0) if loss_config else 1.0
        auxiliary_loss = F.mse_loss(v_t_combined, u_t)

        # Load-balancing loss, computed per horizon segment: within each segment,
        # penalize uneven average gate probability across the heads active there.
        loss_components = []
        boundaries = sorted(set([0] + self.horizons))
        for i in range(len(boundaries) - 1):
            start_step, end_step = boundaries[i], boundaries[i + 1]
            active_expert_indices = [idx for idx, h in enumerate(self.horizons) if h > start_step]

            if len(active_expert_indices) > 1:
                segment_gate_weights = gate_weights[:, start_step:end_step, :]
                active_expert_weights = segment_gate_weights[:, :, active_expert_indices]
                avg_expert_prob_in_segment = active_expert_weights.mean(dim=(0, 1))
                segment_loss = self.cv_squared(avg_expert_prob_in_segment)
                loss_components.append(segment_loss)

        load_balancing_loss = torch.mean(torch.stack(loss_components)) if loss_components else torch.tensor(0.0, device=device)
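        # Segment sketch (assuming horizons = [2, 5, 8]): boundaries = [0, 2, 5, 8] gives
        # segments [0, 2) with active heads {0, 1, 2}, [2, 5) with {1, 2}, and [5, 8)
        # with only {2}, which is skipped since balancing a single head is meaningless.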
        balance_loss_weight = loss_config.get("balance_weight", 0.001) if loss_config else 0.001

        total_loss = individual_loss + aux_loss_weight * auxiliary_loss + balance_loss_weight * load_balancing_loss

        return total_loss

    @torch.no_grad()
    def predict_action(
        self,
        vl_embs: torch.Tensor,
        state: Optional[torch.Tensor] = None,
        ret_weights: bool = False,
    ) -> dict:
        """
        Inference with the MoH ensemble.

        Args:
            vl_embs: (B, seq_length, feature_dim)
            state: (B, 1, state_dim) - Optional
            ret_weights: Whether to also return the gate weights

        Returns:
            dict with 'actions' and, if ret_weights is set, 'gate_weights'
        """
        batch_size = vl_embs.shape[0]
        device = vl_embs.device
        num_horizons = len(self.horizons)
        max_horizon = self.max_horizon

        # Start from pure noise (t = 0) and integrate the learned velocity field to t = 1.
        actions = torch.randn(
            size=(batch_size, max_horizon, self.action_dim),
            dtype=vl_embs.dtype,
            device=device,
        )

        num_steps = self.num_inference_timesteps
        dt = 1.0 / num_steps

        gate_weights_to_log = []

        # Conditioning tensors are batch-major, exactly as in forward():
        # [b0h0, ..., b0hN-1, b1h0, ...].
        batched_vl_embs = vl_embs.repeat_interleave(num_horizons, dim=0)
        batched_state = state.repeat_interleave(num_horizons, dim=0) if state is not None else None

        for t in range(num_steps):
            t_cont = t / float(num_steps)
            t_discretized = int(t_cont * self.num_timestep_buckets)
            timesteps_tensor = torch.full(
                size=(batch_size * num_horizons,),
                fill_value=t_discretized,
                device=device,
            )

            # Pad each horizon's action prefix to max_horizon. Stacking along dim=1
            # (rather than torch.cat along dim=0) keeps the batch-major layout
            # consistent with repeat_interleave above; a plain cat would be
            # horizon-major and silently pair samples with the wrong conditioning.
            padded_x_t_list = []
            for h in self.horizons:
                padded_x_t = F.pad(actions[:, :h, :], (0, 0, 0, max_horizon - h))
                padded_x_t_list.append(padded_x_t)

            # (B, num_horizons, max_horizon, D) -> (B * num_horizons, max_horizon, D)
            batched_x_t = torch.stack(padded_x_t_list, dim=1).reshape(
                batch_size * num_horizons, max_horizon, -1
            )

            action_features = self.action_encoder(batched_x_t, timesteps_tensor)

            if self.config.add_pos_embed:
                pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
                pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
                action_features = action_features + pos_embs

            future_tokens = self.future_tokens.weight.unsqueeze(0).expand(
                batch_size * num_horizons, -1, -1
            )
            if batched_state is not None:
                state_features = self.state_encoder(batched_state)
                sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
            else:
                sa_embs = torch.cat((future_tokens, action_features), dim=1)

            model_output = self.model(
                hidden_states=sa_embs,
                encoder_hidden_states=batched_vl_embs,
                timestep=timesteps_tensor,
            )

            pred = self.action_decoder(model_output)

            # Recover the action-token slice, as in forward().
            state_offset = 1 if state is not None else 0
            future_tokens_len = self.future_tokens.num_embeddings
            action_start_idx = state_offset + future_tokens_len
            pred_actions_padded = pred[:, action_start_idx:action_start_idx + max_horizon, :]

            # Batch-major layout: (B * num_horizons, H, D) -> (B, num_horizons, H, D).
            all_v_t_preds_padded = pred_actions_padded.view(
                batch_size, num_horizons, max_horizon, -1
            )

            # Gating and masked softmax, identical to the training path (without noise).
            gate_logits = self.gate_out_proj(model_output.to(torch.float32))
            gate_logits = gate_logits[:, action_start_idx:action_start_idx + max_horizon, :]
            gate_logits = gate_logits.reshape(batch_size, num_horizons, max_horizon).permute(0, 2, 1)

            valid_heads_mask = torch.tensor(
                [[step < h for h in self.horizons] for step in range(max_horizon)],
                device=device, dtype=torch.bool,
            ).unsqueeze(0)

            masked_gate_logits = torch.where(
                valid_heads_mask,
                gate_logits,
                torch.finfo(gate_logits.dtype).min,
            )
            gate_weights = F.softmax(masked_gate_logits, dim=-1)

            if ret_weights:
                gate_weights_to_log.append(torch.round(gate_weights, decimals=3))

            # Ensemble the per-horizon velocity predictions: (B, H, D).
            v_t = (gate_weights.permute(0, 2, 1).unsqueeze(-1) * all_v_t_preds_padded).sum(dim=1)

            # Explicit Euler step of the flow ODE: x <- x + dt * v.
            actions = actions + dt * v_t

        return_dict = {"actions": actions}
        if ret_weights and len(gate_weights_to_log) > 0:
            return_dict["gate_weights"] = torch.stack(gate_weights_to_log, dim=1).detach().cpu()

        return return_dict


def get_action_model(config=None, horizons: Optional[List[int]] = None):
    """
    Factory: build a FlowmatchingActionHeadMoH from the global framework config.

    Args:
        config: Global config (expects a config.framework.action_model namespace).
        horizons: List of horizon lengths to use for MoH (defaults to [2, 5, 8]).

    Returns:
        FlowmatchingActionHeadMoH: Initialized MoH ActionHeader.
    """
    return FlowmatchingActionHeadMoH(
        full_config=config,
        horizons=horizons if horizons is not None else [2, 5, 8],
    )
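

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the module API). The config namespace
# below is hypothetical: field names are inferred from the attribute accesses
# in __init__ above, the Beta/noise values are assumed, and the DiT kwargs and
# the vl_embs feature dim depend on the DiT implementation in this repo, so
# adjust all of them to match your actual setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    action_model_cfg = SimpleNamespace(
        hidden_size=1024,
        action_model_type="DiT-B",
        diffusion_model_cfg={},          # assumed: extra DiT kwargs, if any
        action_dim=7,
        future_action_window_size=7,     # only used for self.action_horizon here
        num_inference_timesteps=4,
        state_dim=14,
        num_target_vision_tokens=32,
        add_pos_embed=True,
        max_seq_len=1024,
        noise_beta_alpha=1.5,            # assumed Beta parameters
        noise_beta_beta=1.0,
        noise_s=0.999,
        num_timestep_buckets=1000,
    )
    cfg = SimpleNamespace(framework=SimpleNamespace(action_model=action_model_cfg))

    head = get_action_model(cfg, horizons=[2, 5, 8])
    vl_embs = torch.randn(2, 48, head.input_embedding_dim)  # (B, seq, feature_dim), dim assumed
    state = torch.randn(2, 1, 14)                           # (B, 1, state_dim)
    actions = torch.randn(2, head.max_horizon, 7)           # padded to max_horizon

    loss = head(vl_embs, actions, state=state)
    out = head.predict_action(vl_embs, state=state, ret_weights=True)
    print("loss:", float(loss), "actions:", out["actions"].shape)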