# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List, Dict, Any
from src.model.encoder.vggt.layers import PatchEmbed
from src.model.encoder.vggt.layers.block import Block
from src.model.encoder.vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from src.model.encoder.vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
logger = logging.getLogger(__name__)
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]
class Aggregator(nn.Module):
"""
The Aggregator applies alternating-attention over input frames,
as described in VGGT: Visual Geometry Grounded Transformer.
Args:
img_size (int): Image size in pixels.
patch_size (int): Size of each patch for PatchEmbed.
embed_dim (int): Dimension of the token embeddings.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
num_register_tokens (int): Number of register tokens.
block_fn (nn.Module): The block type used for attention (Block by default).
qkv_bias (bool): Whether to include bias in QKV projections.
proj_bias (bool): Whether to include bias in the output projection.
ffn_bias (bool): Whether to include bias in MLP layers.
patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
        aa_block_size (int): How many consecutive blocks to run for each attention type
            before switching. Set to 1 to alternate after every block.
qk_norm (bool): Whether to apply QK normalization.
rope_freq (int): Base frequency for rotary embedding. -1 to disable.
init_values (float): Init scale for layer scale.
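
    Example (illustrative sketch; the defaults below correspond to the ViT-L configuration,
    and instantiating it builds a large, randomly initialized DINOv2 patch embed unless
    weights are loaded from a checkpoint):
        >>> aggregator = Aggregator()
        >>> images = torch.rand(1, 2, 3, 518, 518)             # [B, S, 3, H, W] in [0, 1]
        >>> output_list, patch_start_idx = aggregator(images)
        >>> output_list[-1].shape                               # [B, S, P, 2 * embed_dim]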
"""
def __init__(
self,
img_size=518,
patch_size=14,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4.0,
num_register_tokens=4,
block_fn=Block,
qkv_bias=True,
proj_bias=True,
ffn_bias=True,
patch_embed="dinov2_vitl14_reg",
aa_order=["frame", "global"],
aa_block_size=1,
qk_norm=True,
rope_freq=100,
init_values=0.01,
):
super().__init__()
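        # Gradient checkpointing: recompute block activations during backward to reduce memory.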
self.use_checkpoint = True
self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)
# Initialize rotary position embedding if frequency > 0
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
self.position_getter = PositionGetter() if self.rope is not None else None
self.frame_blocks = nn.ModuleList(
[
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
ffn_bias=ffn_bias,
init_values=init_values,
qk_norm=qk_norm,
rope=self.rope,
)
for _ in range(depth)
]
)
self.global_blocks = nn.ModuleList(
[
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
ffn_bias=ffn_bias,
init_values=init_values,
qk_norm=qk_norm,
rope=self.rope,
)
for _ in range(depth)
]
)
self.depth = depth
self.aa_order = aa_order
self.patch_size = patch_size
self.aa_block_size = aa_block_size
# Validate that depth is divisible by aa_block_size
if self.depth % self.aa_block_size != 0:
raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
self.aa_block_num = self.depth // self.aa_block_size
# Note: We have two camera tokens, one for the first frame and one for the rest
# The same applies for register tokens
self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
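        # e.g. with S=3, each sequence receives camera tokens [index-0 token, index-1 token, index-1 token]
        # after slice_expand_and_flatten below; register tokens follow the same pattern.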
# The patch tokens start after the camera and register tokens
self.patch_start_idx = 1 + num_register_tokens
# Initialize parameters with small values
nn.init.normal_(self.camera_token, std=1e-6)
nn.init.normal_(self.register_token, std=1e-6)
# Register normalization constants as buffers
for name, value in (
("_resnet_mean", _RESNET_MEAN),
("_resnet_std", _RESNET_STD),
):
self.register_buffer(
name,
torch.FloatTensor(value).view(1, 1, 3, 1, 1),
persistent=False,
)
def __build_patch_embed__(
self,
patch_embed,
img_size,
patch_size,
num_register_tokens,
interpolate_antialias=True,
interpolate_offset=0.0,
block_chunks=0,
init_values=1.0,
embed_dim=1024,
):
"""
        Build the patch embedding layer. If `patch_embed` contains "conv", use a simple
        PatchEmbed conv layer; otherwise instantiate the requested DINOv2 vision transformer.
"""
if "conv" in patch_embed:
self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
else:
vit_models = {
"dinov2_vitl14_reg": vit_large,
"dinov2_vitb14_reg": vit_base,
"dinov2_vits14_reg": vit_small,
"dinov2_vitg2_reg": vit_giant2,
}
self.patch_embed = vit_models[patch_embed](
img_size=img_size,
patch_size=patch_size,
num_register_tokens=num_register_tokens,
interpolate_antialias=interpolate_antialias,
interpolate_offset=interpolate_offset,
block_chunks=block_chunks,
init_values=init_values,
)
# Disable gradient updates for mask token
if hasattr(self.patch_embed, "mask_token"):
self.patch_embed.mask_token.requires_grad_(False)
def forward(
self,
images: torch.Tensor,
intermediate_layer_idx: Optional[List[int]] = None
) -> Tuple[List[torch.Tensor], int]:
"""
Args:
            images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
            intermediate_layer_idx (Optional[List[int]]): Indices of layers whose outputs
                should be kept. The last layer (depth - 1) is always included. If None,
                the outputs of all layers are returned.
        Returns:
            (list[torch.Tensor], int):
                A list of per-layer outputs, each of shape [B, S, P, 2 * embed_dim]
                (frame and global features concatenated along the channel dimension),
                and patch_start_idx indicating where patch tokens begin.
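
        Example (illustrative; `aggregator` and `images` as in the class docstring,
        and the layer indices below are arbitrary):
            >>> outputs, patch_start_idx = aggregator(images, intermediate_layer_idx=[4, 11, 17, 23])
            >>> len(outputs)   # 4: one [B, S, P, 2 * embed_dim] tensor per requested layer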
"""
B, S, C_in, H, W = images.shape
if C_in != 3:
raise ValueError(f"Expected 3 input channels, got {C_in}")
# Normalize images and reshape for patch embed
images = (images - self._resnet_mean) / self._resnet_std
# Reshape to [B*S, C, H, W] for patch embedding
images = images.view(B * S, C_in, H, W)
patch_tokens = self.patch_embed(images)
if isinstance(patch_tokens, dict):
patch_tokens = patch_tokens["x_norm_patchtokens"]
_, P, C = patch_tokens.shape
# Expand camera and register tokens to match batch size and sequence length
camera_token = slice_expand_and_flatten(self.camera_token, B, S)
register_token = slice_expand_and_flatten(self.register_token, B, S)
# Concatenate special tokens with patch tokens
tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)
pos = None
if self.rope is not None:
pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
if self.patch_start_idx > 0:
# do not use position embedding for special tokens (camera and register tokens)
# so set pos to 0 for the special tokens
pos = pos + 1
pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
pos = torch.cat([pos_special, pos], dim=1)
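                # pos now covers all tokens: special tokens sit at (0, 0) while the
                # patch-grid coordinates are shifted to start at (1, 1).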
# update P because we added special tokens
_, P, C = tokens.shape
frame_idx = 0
global_idx = 0
output_list = []
layer_idx = 0
# Convert intermediate_layer_idx to a set for O(1) lookup
if intermediate_layer_idx is not None:
required_layers = set(intermediate_layer_idx)
# Always include the last layer for camera_head
required_layers.add(self.depth - 1)
for _ in range(self.aa_block_num):
for attn_type in self.aa_order:
if attn_type == "frame":
tokens, frame_idx, frame_intermediates = self._process_frame_attention(
tokens, B, S, P, C, frame_idx, pos=pos
)
elif attn_type == "global":
tokens, global_idx, global_intermediates = self._process_global_attention(
tokens, B, S, P, C, global_idx, pos=pos
)
else:
raise ValueError(f"Unknown attention type: {attn_type}")
if intermediate_layer_idx is not None:
for i in range(len(frame_intermediates)):
current_layer = layer_idx + i
if current_layer in required_layers:
# concat frame and global intermediates, [B x S x P x 2C]
concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
output_list.append(concat_inter)
layer_idx += self.aa_block_size
else:
for i in range(len(frame_intermediates)):
# concat frame and global intermediates, [B x S x P x 2C]
concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
output_list.append(concat_inter)
del concat_inter
del frame_intermediates
del global_intermediates
return output_list, self.patch_start_idx
def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
"""
Process frame attention blocks. We keep tokens in shape (B*S, P, C).
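        Attention is restricted to tokens of the same frame by folding the sequence
        dimension S into the batch dimension.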
"""
# If needed, reshape tokens or positions:
if tokens.shape != (B * S, P, C):
tokens = tokens.view(B, S, P, C).view(B * S, P, C)
if pos is not None and pos.shape != (B * S, P, 2):
pos = pos.view(B, S, P, 2).view(B * S, P, 2)
intermediates = []
# by default, self.aa_block_size=1, which processes one block at a time
for _ in range(self.aa_block_size):
if self.use_checkpoint:
tokens = torch.utils.checkpoint.checkpoint(
self.frame_blocks[frame_idx],
tokens,
pos,
use_reentrant=False,
)
else:
tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
frame_idx += 1
intermediates.append(tokens.view(B, S, P, C))
return tokens, frame_idx, intermediates
def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
"""
Process global attention blocks. We keep tokens in shape (B, S*P, C).
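        Tokens attend across all frames of a sequence by flattening the S and P
        dimensions into a single token axis per batch element.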
"""
if tokens.shape != (B, S * P, C):
tokens = tokens.view(B, S, P, C).view(B, S * P, C)
if pos is not None and pos.shape != (B, S * P, 2):
pos = pos.view(B, S, P, 2).view(B, S * P, 2)
intermediates = []
# by default, self.aa_block_size=1, which processes one block at a time
for _ in range(self.aa_block_size):
if self.use_checkpoint:
tokens = torch.utils.checkpoint.checkpoint(
self.global_blocks[global_idx],
tokens,
pos,
use_reentrant=False,
)
else:
tokens = self.global_blocks[global_idx](tokens, pos=pos)
global_idx += 1
intermediates.append(tokens.view(B, S, P, C))
return tokens, global_idx, intermediates
def slice_expand_and_flatten(token_tensor, B, S):
"""
Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
1) Uses the first position (index=0) for the first frame only
2) Uses the second position (index=1) for all remaining frames (S-1 frames)
3) Expands both to match batch size B
4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
followed by (S-1) second-position tokens
5) Flattens to (B*S, X, C) for processing
Returns:
torch.Tensor: Processed tokens with shape (B*S, X, C)
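
    Example (illustrative): for the camera token of shape (1, 2, 1, C) with B=2 and S=3,
    the result has shape (6, 1, C), and each length-3 sequence is ordered as
    [index-0 token, index-1 token, index-1 token].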
"""
    # Take the first-frame token (index 0), shape (1, 1, ...), and expand to (B, 1, ...)
query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
    # Take the remaining-frames token (index 1), shape (1, 1, ...), and expand to (B, S - 1, ...)
others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
# Concatenate => shape (B, S, ...)
combined = torch.cat([query, others], dim=1)
# Finally flatten => shape (B*S, ...)
combined = combined.view(B * S, *combined.shape[2:])
return combined
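

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only, not part of the upstream VGGT
# release). It builds a deliberately tiny Aggregator with the simple "conv"
# patch embed so it runs quickly on CPU; the shapes mirror the real model, but
# the weights are random and the hyperparameters below are chosen for speed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    aggregator = Aggregator(
        img_size=28,
        patch_size=14,
        embed_dim=64,
        depth=2,
        num_heads=2,
        num_register_tokens=4,
        patch_embed="conv",
    )
    aggregator.use_checkpoint = False  # no need to trade compute for memory here
    aggregator.eval()

    images = torch.rand(1, 2, 3, 28, 28)  # [B, S, 3, H, W] in [0, 1]
    with torch.no_grad():
        output_list, patch_start_idx = aggregator(images)

    # One concatenated [frame | global] feature map per layer: [B, S, P, 2 * embed_dim]
    print(len(output_list), output_list[0].shape, patch_start_idx)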