import math
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from .configuration_eo1 import EO1VisionFlowMatchingConfig
from .modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration

logger = logging.get_logger(__name__)


def create_sinusoidal_pos_embedding(
    time: torch.Tensor,
    dimension: int,
    min_period: float = 4e-3,
    max_period: float = 4.0,
    device="cpu",
) -> Tensor:
    """Computes sine-cosine positional embedding vectors for scalar positions."""
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size,)`.")

    # Geometrically spaced periods between min_period and max_period.
    fraction = torch.linspace(0.0, 1.0, dimension // 2, device=device)
    period = min_period * (max_period / min_period) ** fraction

    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
    return pos_emb
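
# Shape sketch (illustrative, not part of the module's API): for a batch of two scalar
# timesteps and an 8-dim embedding, the sin half and cos half are concatenated along the
# feature dimension:
#   t = torch.tensor([0.25, 0.75])
#   emb = create_sinusoidal_pos_embedding(t, dimension=8)
#   assert emb.shape == (2, 8)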


@dataclass
class EO1VisionFlowMatchingOutputWithPast(ModelOutput):
    """Model output carrying the combined loss, its flow-matching (fm) and autoregressive (ar)
    components, denoised actions, language-model logits, and cache/attention tensors."""

    loss: torch.FloatTensor | None = None
    fm_loss: torch.FloatTensor | None = None
    ar_loss: torch.FloatTensor | None = None

    actions: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None

    past_key_values: list[torch.FloatTensor] | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    rope_deltas: torch.LongTensor | None = None


class EO1VisionActionProjector(torch.nn.Sequential):
    """Multi-layer perceptron (MLP) head that projects hidden states to action space."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 2,
        activation_layer: str = "linear",
        bias: bool = True,
        device: Any = None,
        dtype: torch.dtype = torch.float32,
    ):
        layers = []
        in_dim = in_channels
        # All hidden layers keep the input width; only the last layer maps to out_channels.
        hidden_channels = [in_dim] * (num_layers - 1) + [out_channels]
        for hidden_dim in hidden_channels[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device))
            layers.append(ACT2FN[activation_layer])
            in_dim = hidden_dim
        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias, dtype=dtype, device=device))
        super().__init__(*layers)

    @property
    def dtype(self):
        return self[0].weight.dtype
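
# Layer-structure sketch (hypothetical sizes): EO1VisionActionProjector(1024, 32,
# num_layers=2, activation_layer="silu") builds Linear(1024, 1024) -> SiLU -> Linear(1024, 32);
# with num_layers=1 it reduces to a single Linear(1024, 32).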


class EO1VisionFlowMatchingModel(PreTrainedModel, GenerationMixin):
    """Vision-language-action model coupling a Qwen2.5-VL backbone with a flow-matching action head."""

    config_class = EO1VisionFlowMatchingConfig
    supports_gradient_checkpointing = True

    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_attention_backend = True
    _can_compile_fullgraph = True
    _skip_keys_device_placement = "past_key_values"

    def __init__(
        self,
        config: EO1VisionFlowMatchingConfig,
        vlm_backbone: Qwen2_5_VLForConditionalGeneration | None = None,
    ):
        super().__init__(config)

        hidden_size = self.config.text_config.hidden_size
        max_action_dim = self.config.max_action_dim
        self.vlm_backbone = vlm_backbone or Qwen2_5_VLForConditionalGeneration(self.config)
        # Projections between the (padded) state/action space and the LLM hidden space.
        self.state_proj = nn.Linear(max_action_dim, hidden_size)
        self.action_in_proj = nn.Linear(max_action_dim, hidden_size)
        self.action_out_proj = EO1VisionActionProjector(
            hidden_size,
            max_action_dim,
            self.config.num_action_layers,
            self.config.action_act,
        )
        # Fuses the action embedding with the flow-matching timestep embedding.
        self.action_time_mlp_in = nn.Linear(hidden_size * 2, hidden_size)
        self.action_time_mlp_out = nn.Linear(hidden_size, hidden_size)

        self.post_init()
        self.to_float32_flow_matching_head()

    def get_input_embeddings(self):
        return self.vlm_backbone.get_input_embeddings()

    def to_float32_flow_matching_head(self):
        # Keep the flow-matching head parameters in float32 (the backbone may run in lower precision).
        self.action_out_proj = self.action_out_proj.to(dtype=torch.float32)
        self.action_time_mlp_in = self.action_time_mlp_in.to(dtype=torch.float32)
        self.action_time_mlp_out = self.action_time_mlp_out.to(dtype=torch.float32)
        self.state_proj = self.state_proj.to(dtype=torch.float32)
        self.action_in_proj = self.action_in_proj.to(dtype=torch.float32)

    def sample_noise(self, shape, device):
        noise = torch.normal(
            mean=0.0,
            std=1.0,
            size=shape,
            dtype=torch.float32,
            device=device,
        )
        return noise

    def sample_time(self, bsize, device):
        # Beta(1.5, 1.0) skews samples toward 1; rescale to [0.001, 1.0) so the
        # flow-matching time never hits exactly 0.
        beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
        time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
        time = time_beta * 0.999 + 0.001
        return time

    def replace_special_embeddings(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        special_features: torch.FloatTensor | None = None,
        special_token_ids: torch.LongTensor | None = None,
    ) -> tuple[torch.FloatTensor, None]:
        """Scatter `special_features` into the embedding positions marked by `special_token_ids`."""
        if special_features is not None and special_token_ids is not None:
            n_special_tokens = (input_ids == special_token_ids).sum().item()
            n_special_features = special_features.shape[0]
            assert n_special_tokens == n_special_features, (
                f"Special features and special tokens {special_token_ids} do not match: "
                f"tokens: {n_special_tokens}, features {n_special_features}"
            )
            mask = input_ids == special_token_ids
            mask_unsqueezed = mask.unsqueeze(-1)
            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
            special_mask = mask_expanded.to(inputs_embeds.device)
            special_features = special_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_mask, special_features)
        return inputs_embeds, None
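
    # Replacement sketch (illustrative): if input_ids == special_token_ids at k positions
    # and special_features has shape (k, hidden), masked_scatter overwrites those k rows of
    # inputs_embeds, in sequence order, with the rows of special_features.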

    def embed_prefix(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.Tensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        video_grid_thw: torch.LongTensor | None = None,
        states: torch.Tensor | None = None,
    ) -> torch.FloatTensor:
        """Embed the prefix: text tokens plus image, video, and robot-state features."""
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            image_embeds = self.vlm_backbone.get_image_features(pixel_values, image_grid_thw)
            image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            image_mask, _ = self.vlm_backbone.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        if pixel_values_videos is not None:
            video_embeds = self.vlm_backbone.get_video_features(pixel_values_videos, video_grid_thw)
            video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            _, video_mask = self.vlm_backbone.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        if states is not None:
            # Project the (padded) robot state into the LLM hidden space and scatter it over
            # the state placeholder tokens.
            states = states.type(self.state_proj.weight.dtype)
            state_embs = self.state_proj(states)
            inputs_embeds, _ = self.replace_special_embeddings(
                input_ids, inputs_embeds, state_embs, self.config.state_token_id
            )
        return inputs_embeds

    def embed_suffix(
        self,
        timestep: torch.Tensor,
        noisy_actions: torch.Tensor,
    ) -> torch.FloatTensor:
        """Embed the suffix: noisy actions fused with their flow-matching timestep."""
        time_embs = create_sinusoidal_pos_embedding(
            timestep,
            self.config.text_config.hidden_size,
            device=noisy_actions.device,
        )
        time_embs = time_embs.type(noisy_actions.dtype)
        noisy_actions = noisy_actions.type(self.action_in_proj.weight.dtype)
        action_embs = self.action_in_proj(noisy_actions)
        time_embs = time_embs[:, None, :].expand_as(action_embs)

        # Concatenate action and time embeddings, then fuse them with a small MLP.
        action_time_embs = torch.cat([action_embs, time_embs], dim=2)
        action_time_embs = self.action_time_mlp_in(action_time_embs)
        action_time_embs = F.silu(action_time_embs)
        action_time_embs = self.action_time_mlp_out(action_time_embs)
        return action_time_embs
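
    # Shape sketch (illustrative): with hidden size H, a batch of B action chunks of
    # length T and (padded) action dimension A maps as
    #   noisy_actions: (B, T, A) -> action_embs: (B, T, H)
    #   timestep:      (B,)      -> time_embs:   (B, T, H)  (broadcast over T)
    #   output action_time_embs: (B, T, H)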

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        pixel_values: torch.Tensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        video_grid_thw: torch.LongTensor | None = None,
        rope_deltas: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        second_per_grid_ts: torch.Tensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        states: torch.Tensor | None = None,
        actions: torch.Tensor | None = None,
        action_is_pad: torch.Tensor | None = None,
        **kwargs,
    ) -> EO1VisionFlowMatchingOutputWithPast:
        """Multi-modal forward pass over image, video, state, action, and language tokens."""
        inputs_embeds = self.embed_prefix(
            input_ids,
            inputs_embeds,
            pixel_values,
            pixel_values_videos,
            image_grid_thw,
            video_grid_thw,
            states,
        )

        if actions is not None:
            # Action placeholder positions: `action_token_id` positions are noised and
            # denoised with flow matching; `action_pass_id` positions keep the clean
            # action (t = 0) and are excluded from the flow-matching loss below.
            noise_mask = input_ids == self.config.action_token_id
            pass_mask = input_ids == self.config.action_pass_id
            mask = noise_mask | pass_mask

            pass_mask_in_action = pass_mask[mask]
            pass_mask_in_action = pass_mask_in_action.reshape(*actions.shape[:2], 1)

            # Sample one flow-matching time per sequence; pass-through actions get t = 0.
            time = self.sample_time(actions.shape[0], inputs_embeds.device)
            time_expanded = time[:, None, None].repeat(1, actions.shape[1], 1)
            time_expanded[pass_mask_in_action] = 0.0

            # Linear interpolation between clean actions and noise: x_t = t * eps + (1 - t) * a,
            # with target velocity u_t = eps - a.
            noise = self.sample_noise(actions.shape, inputs_embeds.device)
            x_t = time_expanded * noise + (1 - time_expanded) * actions
            u_t = noise - actions

            action_time_embs = self.embed_suffix(time, x_t)
            mask_unsqueezed = mask.unsqueeze(-1)
            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
            action_mask = mask_expanded.to(inputs_embeds.device)

            action_time_embs = action_time_embs.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(action_mask, action_time_embs)
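
            # Worked example (illustrative): with t = 0.25, a clean action value a = 1.0 and
            # noise eps = -1.0, x_t = 0.25 * (-1.0) + 0.75 * 1.0 = 0.5 and the regression
            # target is u_t = eps - a = -2.0.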

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

        if position_ids is None:
            # Recompute the 3D (temporal/height/width) RoPE indices on prefill; reuse the
            # cached rope_deltas for incremental decoding.
            prefill_noncompiled_stage = (cache_position is not None and cache_position[0] == 0) or (
                past_key_values is None or past_key_values.get_seq_length() == 0
            )
            if prefill_noncompiled_stage or self.vlm_backbone.rope_deltas is None:
                position_ids, rope_deltas = self.vlm_backbone.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts=second_per_grid_ts,
                    attention_mask=attention_mask,
                )
                self.vlm_backbone.rope_deltas = rope_deltas
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
                if cache_position is not None:
                    delta = (cache_position[0] + self.vlm_backbone.rope_deltas).to(inputs_embeds.device)
                else:
                    delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
                # Out-of-place add: position_ids is an expanded (non-writable) view.
                position_ids = position_ids + delta.to(position_ids.device)

        output_actions = None
        if not (self.training or states is None):
            # Inference with a state prefix: run iterative flow-matching denoising to
            # produce an action chunk.
            output_actions, outputs = self.sample_actions(
                input_ids=input_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                cache_position=cache_position,
                states=states,
            )
        else:
            outputs = self.vlm_backbone.model(
                position_ids=position_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                cache_position=cache_position,
            )

        hidden_states = outputs[0]

        # Only compute logits for the requested positions (logits_to_keep == 0 keeps the full sequence).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.vlm_backbone.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        fm_loss = None
        v_t = None
        if actions is not None:
            # Flow-matching loss: predict the velocity u_t from the hidden states at the
            # action token positions.
            action_time_embs = hidden_states[action_mask[..., 0]]
            action_time_embs = action_time_embs.type(self.action_out_proj.dtype)

            v_t = self.action_out_proj(action_time_embs)
            u_t = u_t.reshape(v_t.shape)
            v_t = v_t.type(u_t.dtype)

            losses = F.mse_loss(u_t, v_t, reduction="none")
            if action_is_pad is not None:
                # Mask out padded action steps beyond the episode boundary.
                in_episode_bound = (~action_is_pad).reshape(-1, 1)
                losses = losses * in_episode_bound

            # Mask out pass-through action tokens that are not denoised.
            in_denoise_bound = (~pass_mask_in_action).reshape(-1, 1)
            losses = losses * in_denoise_bound

            fm_loss = losses.mean()
            loss = fm_loss

        ar_loss = None
        if labels is not None:
            # Standard next-token cross-entropy on the language tokens; added to the
            # flow-matching loss when both objectives are active.
            ar_loss = self.vlm_backbone.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )
            loss = loss + ar_loss if loss is not None else ar_loss

        return EO1VisionFlowMatchingOutputWithPast(
            loss=loss,
            fm_loss=fm_loss,
            ar_loss=ar_loss,
            actions=output_actions,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.vlm_backbone.rope_deltas,
        )
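
    # Training-call sketch (hypothetical tensor names; actual keys and shapes depend on the
    # processor used to build the batch):
    #   out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"],
    #               pixel_values=batch["pixel_values"], image_grid_thw=batch["image_grid_thw"],
    #               states=batch["states"], actions=batch["actions"], labels=batch["labels"])
    #   out.loss.backward()   # loss = fm_loss (+ ar_loss when labels are given)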

    @torch.no_grad()
    def sample_actions(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        pixel_values: torch.Tensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        states: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[Tensor, ModelOutput]:
        """Sample an action chunk by iteratively denoising from Gaussian noise."""

        # Multimodal RoPE indices for the full (prefix + action suffix) sequence.
        position_ids, _ = self.vlm_backbone.get_rope_index(
            input_ids,
            image_grid_thw=image_grid_thw,
            attention_mask=attention_mask,
        )

        # Embed the prefix (text, image, and state tokens).
        inputs_embeds = self.embed_prefix(
            input_ids,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            states=states,
        )

        # The suffix holds `action_chunk_size` action tokens plus one trailing token,
        # which is excluded from the denoising forward passes below.
        seq_len = input_ids.shape[-1]
        chunk_size = self.config.action_chunk_size
        suffix_len = -1
        prefix_len = seq_len - chunk_size - 1

        # Run the prefix once and cache its key/value states for the denoising loop.
        outputs = self.vlm_backbone.model(
            position_ids=position_ids[..., :prefix_len],
            attention_mask=attention_mask[:, :prefix_len],
            inputs_embeds=inputs_embeds[:, :prefix_len],
            use_cache=True,
        )

        # Start from Gaussian noise and integrate the learned velocity field from t = 1
        # to t = 0 with fixed Euler steps.
        device = states.device
        actions_shape = (states.shape[0], chunk_size, self.config.max_action_dim)
        noise = self.sample_noise(actions_shape, device)

        x_t = noise.type(self.action_in_proj.weight.dtype)
        dt = torch.tensor(-1.0 / self.config.num_denoise_steps, device=device)
        time = torch.ones(inputs_embeds.shape[0], device=device)
        past_key_values = outputs.past_key_values

        action_mask = input_ids == self.config.action_token_id
        while (time >= -dt / 2).all():
            action_time_embs = self.embed_suffix(time, x_t)
            inputs_embeds[action_mask] = action_time_embs.to(inputs_embeds.dtype)

            # Re-run only the action suffix against the cached prefix.
            past_key_values.crop(prefix_len)
            outputs = self.vlm_backbone.model(
                position_ids=position_ids[..., prefix_len:suffix_len],
                attention_mask=attention_mask[:, :suffix_len],
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds[:, prefix_len:suffix_len],
                use_cache=True,
            )
            action_time_embs = outputs.last_hidden_state[:, :chunk_size]
            action_time_embs = action_time_embs.type(self.action_out_proj.dtype)
            v_t = self.action_out_proj(action_time_embs)

            # Euler step: x_{t+dt} = x_t + dt * v_t (dt is negative).
            x_t += dt * v_t.reshape(x_t.shape)
            time += dt
        return x_t, outputs
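
    # Schedule sketch (illustrative): with num_denoise_steps = 10, dt = -0.1 and the loop
    # runs suffix passes at t = 1.0, 0.9, ..., 0.1 (10 passes), returning the integrated
    # actions at t = 0.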

    def prepare_inputs_for_generation(self, *args, **kwargs):
        return self.vlm_backbone.prepare_inputs_for_generation(*args, **kwargs)

    def _expand_inputs_for_generation(self, *args, **kwargs):
        return self.vlm_backbone._expand_inputs_for_generation(*args, **kwargs)


EO1VisionFlowMatchingModel.register_for_auto_class()
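
# Minimal usage sketch (hypothetical checkpoint path; assumes the checkpoint ships this
# module and configuration_eo1.py via auto_map so that trust_remote_code resolves to
# EO1VisionFlowMatchingModel):
#   from transformers import AutoModel
#   model = AutoModel.from_pretrained("path/to/eo1-checkpoint", trust_remote_code=True)
#   out = model(input_ids=..., attention_mask=..., pixel_values=..., image_grid_thw=...,
#               states=..., actions=..., labels=...)
#   out.loss, out.fm_loss, out.ar_loss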