"""
Qwen-GROOT Framework

A lightweight implementation that couples Qwen2.5-VL with a flow-matching head to
directly predict continuous actions. The flow-matching head is adapted from GR00T N1.5,
with a simple MoE design inspired by Pi_0.
"""
import sys

sys.path.append("/mnt/data/fangyu/code/rewardmodel")

from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from omegaconf import OmegaConf

from starVLA.training.trainer_utils import initialize_overwatch
from deployment.model_server.tools.image_tools import to_pil_preserve

logger = initialize_overwatch(__name__)

IGNORE_INDEX = -100

from starVLA.model.framework.base_framework import baseframework
from starVLA.model.modules.vlm import get_vlm_model
from starVLA.model.modules.action_model.ActionModel_FM import ActionModelFM
from starVLA.model.modules.action_model.configuration_actionmodel import ActionModelConfig
from starVLA.dataloader.gr00t_lerobot.datasets import ACTION_REPRESENTATION_SLICES
from starVLA.training.trainer_utils.trainer_tools import resize_images
from starVLA.model.tools import FRAMEWORK_REGISTRY

@FRAMEWORK_REGISTRY.register("QwenLatent")
class QwenLatent(baseframework):
    """
    Multimodal vision-language-action model.

    Components:
        - Qwen2.5-VL interface that provides fused language/vision token embeddings
        - Flow-matching DiT action head (ActionModelFM) conditioned on a pooled action embedding

    Focus: predict future continuous actions conditioned on images + an instruction.
    """

    @staticmethod
    def _get_last_nonpad_indices(attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Return the index of the last non-padding token for each sequence.

        Works for both tokenizer.padding_side == "left" and "right".
        attention_mask: [B, T] with 1/True for real tokens and 0/False for pads.
        """
        if attention_mask is None:
            raise ValueError("attention_mask cannot be None")
        if attention_mask.dim() != 2:
            raise ValueError(f"attention_mask must be 2D [B,T], got shape {tuple(attention_mask.shape)}")

        # Flip along the time axis so the last real token becomes the first "1";
        # argmax then returns its offset from the end of the sequence.
        mask = attention_mask.to(dtype=torch.long)
        rev_first_one = torch.flip(mask, dims=[1]).argmax(dim=1)
        last_nonpad = mask.size(1) - 1 - rev_first_one
        return last_nonpad
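
    # Illustrative check of the flip/argmax trick above (a sketch, not executed here):
    #   mask = torch.tensor([[1, 1, 1, 0, 0],   # right-padded -> last non-pad index 2
    #                        [0, 0, 1, 1, 1]])  # left-padded  -> last non-pad index 4
    #   QwenLatent._get_last_nonpad_indices(mask)  # -> tensor([2, 4])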

    def __init__(
        self,
        config: Optional[dict] = None,
        **kwargs,
    ) -> None:
        """
        Construct all submodules and cache key configuration values.

        Args:
            config: Hierarchical configuration (OmegaConf/dict) containing framework + trainer sections.
            **kwargs: Reserved for future overrides (unused).
        """
        super().__init__()
        self.config = config
        self.qwen_vl_interface = get_vlm_model(config=self.config)

        # Cache backbone dimensions (the hardcoded layer count assumes a 36-layer Qwen2.5-VL backbone).
        num_vl_layers, llm_hidden_size = 36, self.qwen_vl_interface.model.config.hidden_size
        self.llm_hidden_size = llm_hidden_size
        self.config.framework.qwenvl.vl_hidden_dim = llm_hidden_size
        self.config.framework.qwenvl.num_vl_layers = num_vl_layers

        # Flow-matching action head, optionally initialized from a checkpoint.
        action_model_cfg = getattr(self.config.framework, "action_model", None)
        if action_model_cfg is not None:
            action_model_kwargs = OmegaConf.to_container(action_model_cfg, resolve=True)
            print(f"{action_model_kwargs=}")
            self.action_model = ActionModelFM(ActionModelConfig(**action_model_kwargs))
        else:
            self.action_model = ActionModelFM(ActionModelConfig())
        ckpt_path = getattr(self.config.framework.action_model, "ckpt_path", None)
        if ckpt_path:
            self.action_model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=True)
            print(f"✅ loaded action model from {ckpt_path}")
        print(f"action model loss mode: {self.action_model.config.loss_mode}")

        # Per-dataset conditioning tokens and a single learned query token.
        self.dataset_vocab_size = getattr(self.config.framework.action_model, "dataset_vocab_size", 256)
        self.num_data_tokens = getattr(self.config.framework.qwenvl, "num_data_tokens", 32)
        self.dataset_embed = nn.Embedding(
            self.dataset_vocab_size,
            llm_hidden_size * self.num_data_tokens,
        )

        self.query_token = nn.Parameter(torch.randn(1, 1, llm_hidden_size))
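
        # Dataset-conditioning layout (descriptive): `dataset_embed(dataset_id)` yields a flat
        # vector of size num_data_tokens * llm_hidden_size which forward() reshapes to
        # [num_data_tokens, llm_hidden_size] and prepends to the text/vision tokens, while
        # `query_token` is appended as the single read-out position for the action embedding.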

        # Project the LLM read-out token into the action model's embedding space.
        action_hidden_size = self.action_model.config.hidden_size
        self.action_embed_projector = nn.Sequential(
            nn.Linear(llm_hidden_size, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, action_hidden_size),
        )

        self.chunk_size = self.config.datasets.vla_data.chunk_size
        self.num_history_steps = 0
        self.use_state = self.action_model.use_state
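
        # Config fields touched by this constructor (a hedged sketch; values shown are the
        # getattr defaults above or placeholders, not the project's actual config):
        #
        #   framework:
        #     qwenvl:       {num_data_tokens: 32}        # vl_hidden_dim / num_vl_layers are written back
        #     action_model: {dataset_vocab_size: 256, ckpt_path: null, ...}  # forwarded to ActionModelConfig
        #   datasets:
        #     vla_data:     {chunk_size: ...}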

    def _maybe_log_align_stats(
        self,
        predicted_action_embeddings: torch.Tensor,
        gt_action_embeddings: torch.Tensor,
    ) -> None:
        if getattr(self, "_align_stats_logged", False):
            return
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            if torch.distributed.get_rank() != 0:
                return
        with torch.no_grad():
            pred = predicted_action_embeddings.float()
            gt = gt_action_embeddings.float()
            pred_norm = pred.norm(dim=-1).mean().item()
            gt_norm = gt.norm(dim=-1).mean().item()
            logger.info(
                "Align stats: pred(mean=%.4f,std=%.4f,avg_norm=%.4f) "
                "gt(mean=%.4f,std=%.4f,avg_norm=%.4f)",
                pred.mean().item(),
                pred.std().item(),
                pred_norm,
                gt.mean().item(),
                gt.std().item(),
                gt_norm,
            )
        self._align_stats_logged = True

    def forward(
        self,
        examples: List[dict] = None,
        **kwargs,
    ):
        """
        Args:
            examples: List[dict], each dict requires:
                - image: List[PIL.Image] (multi-view)
                - lang: str instruction
                - action: np.ndarray or list shaped [T, action_dim]
                - state: np.ndarray (only when the action model uses state)
                - dataset_id: int (optional, defaults to 0)
        Returns:
            dict with:
                align_loss (torch.Tensor | None): L1 loss between predicted and ground-truth action embeddings.
                recon_loss (torch.Tensor | None): Flow-matching loss conditioned on the ground-truth embedding.
                predict_loss (torch.Tensor): Flow-matching loss conditioned on the predicted embedding.
        """
        batch_images = [example["image"] for example in examples]
        instructions = [example["lang"] for example in examples]
        actions = [example["action"] for example in examples]
        states = [example["state"] for example in examples] if self.use_state else None
        dataset_ids = [example.get("dataset_id", 0) for example in examples]

        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(
            images=batch_images,
            instructions=instructions,
            chunk_size=self.chunk_size,
        )

        # Prepend dataset-conditioning tokens and append the learned query token.
        if "input_ids" in qwen_inputs:
            dataset_ids_tensor = torch.tensor(
                dataset_ids, device=qwen_inputs["input_ids"].device, dtype=torch.long
            )
            ds_embeds = self.dataset_embed(dataset_ids_tensor).view(
                len(dataset_ids), self.num_data_tokens, self.llm_hidden_size
            )
            token_embeds = self.qwen_vl_interface.model.get_input_embeddings()(qwen_inputs["input_ids"])
            query_embeds = self.query_token.expand(len(dataset_ids), -1, -1)
            qwen_inputs["inputs_embeds"] = torch.cat((ds_embeds, token_embeds, query_embeds), dim=1)
            qwen_inputs.pop("input_ids")
            if "attention_mask" in qwen_inputs:
                prefix_mask = torch.ones(
                    (qwen_inputs["attention_mask"].shape[0], self.num_data_tokens),
                    device=qwen_inputs["attention_mask"].device,
                    dtype=qwen_inputs["attention_mask"].dtype,
                )
                query_mask = torch.ones(
                    (qwen_inputs["attention_mask"].shape[0], 1),
                    device=qwen_inputs["attention_mask"].device,
                    dtype=qwen_inputs["attention_mask"].dtype,
                )
                qwen_inputs["attention_mask"] = torch.cat(
                    (prefix_mask, qwen_inputs["attention_mask"], query_mask), dim=1
                )
            if "position_ids" in qwen_inputs:
                prefix_pos = torch.arange(
                    self.num_data_tokens,
                    device=qwen_inputs["position_ids"].device,
                    dtype=qwen_inputs["position_ids"].dtype,
                ).unsqueeze(0).expand(qwen_inputs["position_ids"].shape[0], -1)
                query_pos = torch.full(
                    (qwen_inputs["position_ids"].shape[0], 1),
                    qwen_inputs["position_ids"].shape[1] + self.num_data_tokens,
                    device=qwen_inputs["position_ids"].device,
                    dtype=qwen_inputs["position_ids"].dtype,
                )
                qwen_inputs["position_ids"] = torch.cat(
                    (prefix_pos, qwen_inputs["position_ids"] + self.num_data_tokens, query_pos), dim=1
                )
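
        # Resulting sequence layout fed to the VLM (assuming all optional keys above were present):
        #   inputs_embeds : [dataset tokens (num_data_tokens) | text/vision tokens | query token]
        #   attention_mask: [1, ..., 1                        | original mask      | 1]
        #   position_ids  : [0 .. num_data_tokens-1           | original + offset  | T + num_data_tokens]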

        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs = self.qwen_vl_interface(
                **qwen_inputs,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )

        last_hidden_states = qwenvl_outputs.hidden_states[-1]

        # Read out the hidden state of the last non-padding token (the appended query token).
        if "attention_mask" in qwen_inputs:
            last_token_indices = self._get_last_nonpad_indices(qwen_inputs["attention_mask"])
            batch_indices = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
            action_token_hidden = last_hidden_states[batch_indices, last_token_indices]
        else:
            action_token_hidden = last_hidden_states[:, -1, :]

        # Project into the action-embedding space and L2-normalize.
        predicted_action_embeddings = self.action_embed_projector(action_token_hidden).float()
        predicted_action_embeddings = F.normalize(predicted_action_embeddings, p=2, dim=-1)

        loss_mode = getattr(self.action_model.config, "loss_mode", "full")

        with torch.autocast("cuda", dtype=torch.float32):
            actions_target = torch.as_tensor(np.array(actions), device=last_hidden_states.device, dtype=torch.float32)

            # Shared flow-matching time steps and noise for the loss terms below.
            B = actions_target.shape[0]
            t = self.action_model._sample_fm_time(B, device=actions_target.device, dtype=actions_target.dtype)
            noise = torch.randn_like(actions_target)
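
            # Note (assumption, not verified here): a common flow-matching parameterization
            # interpolates x_t = (1 - t) * noise + t * actions and regresses the velocity
            # (actions - noise); the exact form is implemented inside
            # ActionModelFM.recon_loss_from_embedding.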

            if loss_mode == "predict_only":
                # Only train the flow-matching head on the predicted (VLM-derived) embedding.
                align_loss = None
                recon_loss = None
                predict_loss = self.action_model.recon_loss_from_embedding(
                    actions=actions_target,
                    action_embedding=predicted_action_embeddings,
                    t=t,
                    noise=noise,
                )
            else:
                # Full mode: align the predicted embedding to the action encoder's output and
                # reconstruct actions from both embeddings with the same (t, noise) pair.
                states_target = None
                if self.use_state:
                    states_target = torch.as_tensor(np.array(states), device=last_hidden_states.device, dtype=torch.float32)

                gt_action_embeddings = self.action_model.encode_actions(
                    actions=actions_target,
                    dataset_ids=dataset_ids,
                    state=states_target,
                )

                self._maybe_log_align_stats(predicted_action_embeddings, gt_action_embeddings)

                align_loss = F.l1_loss(predicted_action_embeddings, gt_action_embeddings.float().detach())
                recon_loss = self.action_model.recon_loss_from_embedding(
                    actions=actions_target,
                    action_embedding=gt_action_embeddings,
                    t=t,
                    noise=noise,
                )
                predict_loss = self.action_model.recon_loss_from_embedding(
                    actions=actions_target,
                    action_embedding=predicted_action_embeddings,
                    t=t,
                    noise=noise,
                )

        return {
            "align_loss": align_loss,
            "recon_loss": recon_loss,
            "predict_loss": predict_loss,
        }
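
    # Note: in "predict_only" mode align_loss and recon_loss are None and only predict_loss
    # should be optimized; otherwise the trainer is expected to combine all three terms
    # (the weighting lives outside this file).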

    @torch.inference_mode()
    def predict_action(
        self,
        examples: List[dict] = None,
        embodiment_tag: Optional[str] = None,
        **kwargs,
    ) -> dict:
        """
        Inference: a single forward pass that directly regresses future actions (no diffusion sampling).

        Steps:
            1. Resize images to the training resolution (if specified)
            2. Encode with QwenVL (hidden states retained)
            3. Project the read-out token and decode an action chunk with the flow-matching head

        Args:
            examples: List of example dicts containing image, lang, etc.
            embodiment_tag: Optional embodiment tag (e.g., "franka", "oxe_rt1", "oxe_bridge").
                If provided, valid action dimensions are extracted based on ACTION_REPRESENTATION_SLICES.
                If None, the full unified action representation is returned.

        Returns:
            dict:
                normalized_actions (np.ndarray): Shape [B, T, action_dim], decoded normalized actions.
                    If embodiment_tag is provided, shape is [B, T, valid_dim] where
                    valid_dim is determined by ACTION_REPRESENTATION_SLICES[embodiment_tag].
        """
        batch_images = [to_pil_preserve(example["image"]) for example in examples]
        instructions = [example["lang"] for example in examples]

        # Default to dataset id 0 when absent, matching forward().
        dataset_ids = [example.get("dataset_id", 0) for example in examples]

        train_obs_image_size = getattr(self.config.datasets.vla_data, "image_size", None)
        if train_obs_image_size:
            batch_images = resize_images(batch_images, target_size=train_obs_image_size)

        qwen_inputs = self.qwen_vl_interface.build_qwenvl_inputs(
            images=batch_images,
            instructions=instructions,
        )
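
        # The dataset-token / query-token injection below mirrors the training-time path
        # in forward(); see the sequence-layout note there.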
        if "input_ids" in qwen_inputs:
            dataset_ids_tensor = torch.tensor(
                dataset_ids, device=qwen_inputs["input_ids"].device, dtype=torch.long
            )
            ds_embeds = self.dataset_embed(dataset_ids_tensor).view(
                len(dataset_ids), self.num_data_tokens, self.llm_hidden_size
            )
            token_embeds = self.qwen_vl_interface.model.get_input_embeddings()(qwen_inputs["input_ids"])
            query_embeds = self.query_token.expand(len(dataset_ids), -1, -1)
            qwen_inputs["inputs_embeds"] = torch.cat((ds_embeds, token_embeds, query_embeds), dim=1)
            qwen_inputs.pop("input_ids")
            if "attention_mask" in qwen_inputs:
                prefix_mask = torch.ones(
                    (qwen_inputs["attention_mask"].shape[0], self.num_data_tokens),
                    device=qwen_inputs["attention_mask"].device,
                    dtype=qwen_inputs["attention_mask"].dtype,
                )
                query_mask = torch.ones(
                    (qwen_inputs["attention_mask"].shape[0], 1),
                    device=qwen_inputs["attention_mask"].device,
                    dtype=qwen_inputs["attention_mask"].dtype,
                )
                qwen_inputs["attention_mask"] = torch.cat(
                    (prefix_mask, qwen_inputs["attention_mask"], query_mask), dim=1
                )
            if "position_ids" in qwen_inputs:
                prefix_pos = torch.arange(
                    self.num_data_tokens,
                    device=qwen_inputs["position_ids"].device,
                    dtype=qwen_inputs["position_ids"].dtype,
                ).unsqueeze(0).expand(qwen_inputs["position_ids"].shape[0], -1)
                query_pos = torch.full(
                    (qwen_inputs["position_ids"].shape[0], 1),
                    qwen_inputs["position_ids"].shape[1] + self.num_data_tokens,
                    device=qwen_inputs["position_ids"].device,
                    dtype=qwen_inputs["position_ids"].dtype,
                )
                qwen_inputs["position_ids"] = torch.cat(
                    (prefix_pos, qwen_inputs["position_ids"] + self.num_data_tokens, query_pos), dim=1
                )

        with torch.autocast("cuda", dtype=torch.bfloat16):
            qwenvl_outputs = self.qwen_vl_interface(
                **qwen_inputs,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )
        last_hidden_states = qwenvl_outputs.hidden_states[-1]

        if "attention_mask" in qwen_inputs:
            last_token_indices = self._get_last_nonpad_indices(qwen_inputs["attention_mask"])
            batch_indices = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
            action_token_hidden = last_hidden_states[batch_indices, last_token_indices]
        else:
            action_token_hidden = last_hidden_states[:, -1, :]

        predicted_action_embeddings = self.action_embed_projector(action_token_hidden).float()
        predicted_action_embeddings = F.normalize(predicted_action_embeddings, p=2, dim=-1)

        # Decode a chunk of normalized actions from the predicted embedding.
        with torch.autocast("cuda", dtype=torch.float32):
            pred_actions = self.action_model.decode_actions(
                predicted_action_embeddings,
                chunk_size=self.chunk_size,
            )

        normalized_actions = pred_actions.detach().cpu().numpy()

        # Optionally slice out the embodiment-specific action dimensions.
        if embodiment_tag is not None:
            if embodiment_tag not in ACTION_REPRESENTATION_SLICES:
                raise ValueError(
                    f"Unknown embodiment tag '{embodiment_tag}'. "
                    f"Known tags: {sorted(ACTION_REPRESENTATION_SLICES.keys())}"
                )
            target_slice = ACTION_REPRESENTATION_SLICES[embodiment_tag]
            normalized_actions = normalized_actions[..., target_slice]

        return {"normalized_actions": normalized_actions}


if __name__ == "__main__":
    import argparse

    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="/fsx/home/yfang/projects/LearnLatent/starVLA/config/training/starvla_train_qwenlatent_oxe.yaml",
        help="Path to YAML config",
    )
    args, clipargs = parser.parse_known_args()

    cfg = OmegaConf.load(args.config_yaml)

    model = QwenLatent(cfg)
    print(model)

    # Build a dummy batch for a quick smoke test.
    image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
    sample = {
        "action": np.random.uniform(-1, 1, size=(15, 7)).astype(np.float16),
        "image": [image],
        "image_past_half": [image],
        "image_past_one": [image],
        "image_future": [image],
        "lang": "put the ball on the table",
        "state": np.random.uniform(-1, 1, size=(1, 8)).astype(np.float16),
    }

    batch = [sample, sample]
    device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    forward_output = model(batch)
    print(f"Predict Loss: {forward_output['predict_loss'].item()}")
    if forward_output["align_loss"] is not None:  # None in "predict_only" loss mode
        print(f"Align Loss: {forward_output['align_loss'].item()}")
        print(f"Recon Loss: {forward_output['recon_loss'].item()}")