|
|
from typing import Literal, Optional, Union
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from einops import rearrange, repeat |
|
|
from transformers.activations import ACT2FN |
|
|
|
|
|
from configuration_step_vl import StepRoboticsVisionEncoderConfig |
|
|
|
|
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor: |
|
|
"""Rotate last dimension halves (used by RoPE).""" |
|
|
x = rearrange(x, "... (d r) -> ... d r", r=2) |
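    # The "(d r)" split with r=2 views the last dimension as interleaved
    # (even, odd) pairs; each pair (x1, x2) is mapped to (-x2, x1) below.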
|
|
x1, x2 = x.unbind(dim=-1) |
|
|
x = torch.stack((-x2, x1), dim=-1) |
|
|
return rearrange(x, "... d r -> ... (d r)") |
|
|
|
|
|
|
|
|
def apply_rotary_emb(freqs: torch.Tensor, |
|
|
t: torch.Tensor, |
|
|
start_index: int = 0, |
|
|
scale: float = 1.0, |
|
|
seq_dim: int = -2) -> torch.Tensor: |
|
|
"""Apply 2D rotary embeddings to queries / keys.""" |
|
|
dtype = t.dtype |
|
|
|
|
|
if t.ndim == 3: |
|
|
seq_len = t.shape[seq_dim] |
|
|
freqs = freqs[-seq_len:] |
|
|
|
|
|
rot_dim = freqs.shape[-1] |
|
|
end_index = start_index + rot_dim |
|
|
assert rot_dim <= t.shape[-1], ( |
|
|
f"feature dimension {t.shape[-1]} is too small for rot_dim {rot_dim}") |
|
|
|
|
|
t_left, t, t_right = ( |
|
|
t[..., :start_index], |
|
|
t[..., start_index:end_index], |
|
|
t[..., end_index:], |
|
|
) |
|
|
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) |
|
|
out = torch.cat((t_left, t, t_right), dim=-1) |
|
|
return out.type(dtype) |
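
# A quick sanity check (illustrative sketch, not part of the module's API): with
# all angles at zero, cos = 1 and sin = 0, so the rotation is the identity.
#
#   t = torch.randn(1, 4, 16, 64)   # (batch, heads, seq_len, head_dim)
#   freqs = torch.zeros(16, 64)     # one angle per rotated feature
#   assert torch.allclose(apply_rotary_emb(freqs, t), t)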
|
|
|
|
|
|
|
|
class EncoderRope2D(nn.Module): |
|
|
"""Cacheable 2D rotary positional embedding.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim: int, |
|
|
max_grid_height: int, |
|
|
max_grid_width: int, |
|
|
use_cls_token: bool = False, |
|
|
theta: Union[int, float] = 10000, |
|
|
max_freq: int = 10, |
|
|
num_freqs: int = 1, |
|
|
        theta_rescale_factor: float = 1.0,
        freqs_for: Literal["lang", "pixel", "constant"] = "lang",
|
|
): |
|
|
super().__init__() |
|
|
self.dim = dim |
|
|
self.max_grid_height = max_grid_height |
|
|
self.max_grid_width = max_grid_width |
|
|
self.use_cls_token = use_cls_token |
|
|
self.theta = theta * theta_rescale_factor**(dim / (dim - 2)) |
|
|
self.max_freq = max_freq |
|
|
        self.num_freqs = num_freqs
        # Only the "lang" inverse-frequency scheme is implemented here;
        # freqs_for, max_freq and num_freqs are stored but not otherwise used.
        self.freqs_for = freqs_for
|
|
cache = self._compute_2d_freqs() |
|
|
self.register_buffer("freqs_cache", cache, persistent=False) |
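        # The cached table has shape (1, 1, H * W (+1 with a cls token), dim)
        # and broadcasts over the (batch, heads) dimensions of q and k.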
|
|
|
|
|
def _compute_inv_freq(self, base: Union[int, float], |
|
|
dim: int) -> torch.Tensor: |
|
|
|
|
|
freqs = 1.0 / (base**( |
|
|
torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) |
|
|
return freqs |
|
|
|
|
|
def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor): |
|
|
freqs = torch.einsum("..., f -> ... f", t.type(inv_freq.dtype), |
|
|
inv_freq) |
|
|
freqs = repeat(freqs, "... n -> ... (n r)", r=2) |
|
|
return freqs |
|
|
|
|
|
def _compute_2d_freqs(self) -> torch.Tensor: |
|
|
grid_h_range = torch.arange(self.max_grid_height, dtype=torch.float) |
|
|
grid_w_range = torch.arange(self.max_grid_width, dtype=torch.float) |
|
|
if self.use_cls_token: |
|
|
grid_h_range += 1 |
|
|
grid_w_range += 1 |
|
|
inv_freq = self._compute_inv_freq(self.theta, self.dim // 2) |
|
|
freqs_h = self._compute_freqs(grid_h_range, inv_freq)[:, None].expand( |
|
|
self.max_grid_height, self.max_grid_width, -1) |
|
|
freqs_w = self._compute_freqs(grid_w_range, inv_freq)[None, :].expand( |
|
|
self.max_grid_height, self.max_grid_width, -1) |
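        # Column (width) angles occupy the first half of the feature dimension
        # and row (height) angles the second half, one entry per grid cell.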
|
|
freqs = torch.cat([freqs_w, freqs_h], dim=-1).reshape( |
|
|
self.max_grid_height * self.max_grid_width, -1) |
|
|
if self.use_cls_token: |
|
|
freqs = torch.cat([torch.zeros(1, freqs.shape[-1]), freqs], dim=0) |
|
|
freqs = freqs[None, None, ...] |
|
|
return freqs |
|
|
|
|
|
def forward(self, q: torch.Tensor, k: torch.Tensor, |
|
|
grid_hw: tuple[int, int]): |
|
|
|
|
|
if grid_hw[0] != self.max_grid_height or grid_hw[1] != self.max_grid_width: |
|
|
rows = torch.arange(grid_hw[0], device=q.device).view(-1, 1) |
|
|
cols = torch.arange(grid_hw[1], device=q.device).view(1, -1) |
|
|
positions = (rows * self.max_grid_width + cols).reshape(-1).to( |
|
|
torch.long) |
|
|
if self.use_cls_token: |
|
|
                positions = torch.cat([
                    torch.zeros(1, dtype=torch.long, device=q.device),
                    positions + 1
                ], dim=0)
|
|
freqs = self.freqs_cache.index_select(2, positions) |
|
|
else: |
|
|
freqs = self.freqs_cache |
|
|
q = apply_rotary_emb(freqs, q) |
|
|
k = apply_rotary_emb(freqs, k) |
|
|
return q, k |
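
# Illustrative usage sketch (shapes and sizes are assumptions, not fixed by the
# class): rotate query/key tensors laid out as (batch, heads, seq, head_dim).
#
#   rope = EncoderRope2D(dim=64, max_grid_height=27, max_grid_width=27)
#   q = torch.randn(2, 16, 27 * 27, 64)
#   k = torch.randn(2, 16, 27 * 27, 64)
#   q, k = rope(q, k, grid_hw=(27, 27))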
|
|
|
|
|
|
|
|
class EncoderLayerScale(nn.Module): |
|
|
"""Per-channel residual scaling used when ls_init_value is set.""" |
|
|
|
|
|
def __init__(self, dim: int, init_values: float): |
|
|
super().__init__() |
|
|
self.gamma = nn.Parameter(torch.full((dim,), init_values)) |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
|
|
return hidden_states * self.gamma |
|
|
|
|
|
|
|
|
class EncoderMLP(nn.Module): |
|
|
"""Feed-forward network used inside each transformer block.""" |
|
|
|
|
|
def __init__(self, hidden_size: int, intermediate_size: int, |
|
|
hidden_act: str): |
|
|
super().__init__() |
|
|
self.c_fc = nn.Linear(hidden_size, intermediate_size, bias=True) |
|
|
self.act_fn = ACT2FN[hidden_act] |
|
|
self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=True) |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
|
|
|
|
|
hidden_states = self.c_proj(self.act_fn(self.c_fc(hidden_states))) |
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class EncoderVisionAttention(nn.Module): |
|
|
"""Multi-head self attention with optional 2D RoPE.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
hidden_size: int, |
|
|
num_heads: int, |
|
|
max_grid_height: int, |
|
|
max_grid_width: int, |
|
|
use_cls_token: bool = False, |
|
|
use_rope2d: bool = True, |
|
|
rope_theta: Union[int, float] = 10000, |
|
|
rope_max_freq: int = 10, |
|
|
rope_num_freqs: int = 1, |
|
|
rope_theta_rescale_factor: float = 1.0, |
|
|
rope_freqs_for: Literal["lang", "pixel", "constant"] = "lang", |
|
|
): |
|
|
super().__init__() |
|
|
if hidden_size % num_heads != 0: |
|
|
raise ValueError( |
|
|
f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})." |
|
|
) |
|
|
self.num_heads = num_heads |
|
|
self.head_dim = hidden_size // num_heads |
|
|
self.scale = self.head_dim**-0.5 |
|
|
self.in_proj_weight = nn.Parameter(torch.zeros(hidden_size * 3, hidden_size)) |
|
|
self.in_proj_bias = nn.Parameter(torch.zeros(hidden_size * 3)) |
|
|
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True) |
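        # in_proj_weight / in_proj_bias pack the query, key and value
        # projections into one (3 * hidden_size, hidden_size) matrix, the same
        # parameter layout nn.MultiheadAttention uses.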
|
|
|
|
|
self.rope = None |
|
|
if use_rope2d: |
|
|
self.rope = EncoderRope2D( |
|
|
dim=self.head_dim, |
|
|
max_grid_height=max_grid_height, |
|
|
max_grid_width=max_grid_width, |
|
|
use_cls_token=use_cls_token, |
|
|
theta=rope_theta, |
|
|
max_freq=rope_max_freq, |
|
|
num_freqs=rope_num_freqs, |
|
|
theta_rescale_factor=rope_theta_rescale_factor, |
|
|
freqs_for=rope_freqs_for, |
|
|
) |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor, grid_hw: tuple[int, int]) -> torch.Tensor: |
|
|
bsz, seq_len, _ = hidden_states.shape |
|
|
qkv = F.linear( |
|
|
hidden_states, |
|
|
self.in_proj_weight, |
|
|
self.in_proj_bias, |
|
|
) |
|
|
q, k, v = qkv.chunk(3, dim=-1) |
|
|
|
|
|
q = q.view(bsz, seq_len, self.num_heads, |
|
|
self.head_dim).transpose(1, 2) |
|
|
k = k.view(bsz, seq_len, self.num_heads, |
|
|
self.head_dim).transpose(1, 2) |
|
|
if self.rope is not None: |
|
|
q, k = self.rope(q, k, grid_hw=grid_hw) |
|
|
v = v.view(bsz, seq_len, self.num_heads, |
|
|
self.head_dim).transpose(1, 2) |
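        # scaled_dot_product_attention (PyTorch >= 2.0) dispatches to a fused
        # flash / memory-efficient kernel when one is available for the dtypes.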
|
|
|
|
|
attn_output = F.scaled_dot_product_attention( |
|
|
q, k, v, is_causal=False, scale=self.scale) |
|
|
attn_output = attn_output.transpose(1, 2).reshape( |
|
|
bsz, seq_len, self.num_heads * self.head_dim) |
|
|
return self.out_proj(attn_output) |
|
|
|
|
|
|
|
|
class EncoderVisionBlock(nn.Module): |
|
|
"""A single Vision Transformer block (self-attention + MLP).""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
hidden_size: int, |
|
|
num_heads: int, |
|
|
mlp_ratio: float, |
|
|
hidden_act: str, |
|
|
layer_norm_eps: float, |
|
|
ls_init_value: Optional[float] = None, |
|
|
max_grid_height: Optional[int] = None, |
|
|
max_grid_width: Optional[int] = None, |
|
|
use_cls_token: bool = False, |
|
|
use_rope2d: bool = True, |
|
|
rope_kwargs: Optional[dict] = None, |
|
|
): |
|
|
super().__init__() |
|
|
rope_kwargs = rope_kwargs or {} |
|
|
self.attn = EncoderVisionAttention( |
|
|
hidden_size, |
|
|
num_heads, |
|
|
max_grid_height=max_grid_height, |
|
|
max_grid_width=max_grid_width, |
|
|
use_cls_token=use_cls_token, |
|
|
use_rope2d=use_rope2d, |
|
|
**rope_kwargs, |
|
|
) |
|
|
self.ln_1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps) |
|
|
self.ln_2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps) |
|
|
|
|
|
intermediate = int(hidden_size * mlp_ratio) |
|
|
self.mlp = EncoderMLP(hidden_size, intermediate, hidden_act) |
|
|
|
|
|
        # LayerScale is only applied when ls_init_value is set; otherwise the
        # residual branches are left unscaled.
        self.ls_1 = (EncoderLayerScale(hidden_size, ls_init_value)
                     if ls_init_value is not None else nn.Identity())
        self.ls_2 = (EncoderLayerScale(hidden_size, ls_init_value)
                     if ls_init_value is not None else nn.Identity())
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor, |
|
|
grid_hw: tuple[int, int]) -> torch.Tensor: |
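        # Pre-LN block: LayerNorm -> self-attention -> optional LayerScale ->
        # residual add, then the same pattern with the MLP.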
|
|
|
|
|
residual = hidden_states |
|
|
hidden_states = self.ln_1(hidden_states) |
|
|
hidden_states = self.attn(hidden_states, grid_hw=grid_hw) |
|
|
hidden_states = residual + self.ls_1(hidden_states) |
|
|
|
|
|
residual = hidden_states |
|
|
hidden_states = self.ln_2(hidden_states) |
|
|
hidden_states = self.mlp(hidden_states) |
|
|
hidden_states = residual + self.ls_2(hidden_states) |
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class EncoderVisionTransformer(nn.Module): |
|
|
"""Stack of encoder blocks parameterised by Step35VisionEncoderConfig.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
embed_dim: int, |
|
|
depth: int, |
|
|
num_heads: int, |
|
|
mlp_ratio: float, |
|
|
hidden_act: str, |
|
|
layer_norm_eps: float, |
|
|
ls_init_value: Optional[float] = None, |
|
|
max_grid_height: Optional[int] = None, |
|
|
max_grid_width: Optional[int] = None, |
|
|
use_cls_token: bool = False, |
|
|
use_rope2d: bool = True, |
|
|
rope_kwargs: Optional[dict] = None, |
|
|
): |
|
|
super().__init__() |
|
|
self.layers = depth |
|
|
rope_kwargs = rope_kwargs or {} |
|
|
self.resblocks = nn.ModuleList([ |
|
|
EncoderVisionBlock(embed_dim, num_heads, mlp_ratio, hidden_act, |
|
|
layer_norm_eps, |
|
|
max_grid_height=max_grid_height, |
|
|
max_grid_width=max_grid_width, |
|
|
use_cls_token=use_cls_token, |
|
|
use_rope2d=use_rope2d, |
|
|
ls_init_value=ls_init_value, |
|
|
rope_kwargs=rope_kwargs) |
|
|
for _ in range(depth) |
|
|
]) |
|
|
|
|
|
def forward(self, |
|
|
hidden_states: torch.Tensor, |
|
|
grid_hw: tuple[int, int]) -> torch.Tensor: |
|
|
for block in self.resblocks: |
|
|
hidden_states = block(hidden_states, grid_hw=grid_hw) |
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class StepRoboticsVisionEncoder(nn.Module): |
|
|
""" |
|
|
Vision encoder built from StepRoboticsVisionEncoderConfig. |
|
|
|
|
|
The encoder performs patch embedding followed by a stack of transformer |
|
|
    blocks. Only the config fields defined in StepRoboticsVisionEncoderConfig
    (exposed as ``vision_config`` on the parent VL config) are expected.
|
|
""" |
|
|
|
|
|
def __init__(self, config: StepRoboticsVisionEncoderConfig): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
|
|
|
|
|
|
|
|
|
self.hidden_size = config.width |
|
|
self.num_heads = config.heads |
|
|
self.num_hidden_layers = config.layers |
|
|
self.patch_size = config.patch_size |
|
|
self.image_size = config.image_size |
|
|
self.use_cls_token = getattr(config, "use_cls_token", False) |
|
|
self.use_rope2d = getattr(config, "use_rope2d", True) |
|
|
self.use_abs_posemb = getattr(config, "use_abs_posemb", True) |
|
|
self.layer_norm_eps = config.layer_norm_eps |
|
|
self.mlp_ratio = getattr(config, "mlp_ratio", 8960 / 1536) |
|
|
self.ls_init_value = getattr(config, "ls_init_value", None) |
|
|
self.hidden_act = config.hidden_act |
|
|
self.use_ln_pre = getattr(config, "use_ln_pre", False) |
|
|
self.use_ln_post = getattr(config, "use_ln_post", True) |
|
|
|
|
|
|
|
|
self.conv1 = nn.Conv2d(in_channels=config.num_channels, |
|
|
out_channels=self.hidden_size, |
|
|
kernel_size=self.patch_size, |
|
|
stride=self.patch_size, |
|
|
bias=False) |
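        # conv1 is the patch embedding: a non-overlapping convolution mapping
        # (B, C, H, W) to (B, hidden_size, H // patch_size, W // patch_size).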
|
|
|
|
|
self.ln_pre = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) if self.use_ln_pre else nn.Identity() |
|
|
self.ln_post = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) if self.use_ln_post else nn.Identity() |
|
|
|
|
|
grid_size = self.image_size // self.patch_size |
|
|
self.base_grid = (grid_size, grid_size) |
|
|
|
|
|
if self.use_cls_token: |
|
|
self.class_embedding = nn.Parameter( |
|
|
torch.randn(self.hidden_size) * (self.hidden_size**-0.5)) |
|
|
else: |
|
|
self.class_embedding = None |
|
|
|
|
|
if self.use_abs_posemb: |
|
|
self.posemb_grid_size = self.image_size // self.patch_size |
|
|
self.positional_embedding = nn.Parameter( |
|
|
(self.hidden_size**-0.5) * torch.randn( |
|
|
int(self.use_cls_token) + self.posemb_grid_size**2, |
|
|
self.hidden_size, |
|
|
)) |
|
|
|
|
|
self.transformer = EncoderVisionTransformer( |
|
|
embed_dim=self.hidden_size, |
|
|
depth=self.num_hidden_layers, |
|
|
num_heads=self.num_heads, |
|
|
mlp_ratio=self.mlp_ratio, |
|
|
hidden_act=self.hidden_act, |
|
|
layer_norm_eps=self.layer_norm_eps, |
|
|
ls_init_value=self.ls_init_value, |
|
|
max_grid_height=self.base_grid[0], |
|
|
max_grid_width=self.base_grid[1], |
|
|
use_cls_token=self.use_cls_token, |
|
|
use_rope2d=self.use_rope2d, |
|
|
rope_kwargs={ |
|
|
"rope_theta": getattr(config, "rope_theta", 10000), |
|
|
"rope_max_freq": getattr(config, "rope_max_freq", 10), |
|
|
"rope_num_freqs": getattr(config, "rope_num_freqs", 1), |
|
|
"rope_theta_rescale_factor": |
|
|
getattr(config, "rope_theta_rescale_factor", 1.0), |
|
|
"rope_freqs_for": getattr(config, "rope_freqs_for", "lang"), |
|
|
}, |
|
|
) |
|
|
self.vit_downsampler1 = nn.Conv2d(self.hidden_size, |
|
|
self.hidden_size * 2, |
|
|
kernel_size=3, |
|
|
stride=2, |
|
|
padding=1) |
|
|
self.vit_downsampler2 = nn.Conv2d(self.hidden_size * 2, |
|
|
self.hidden_size * 4, |
|
|
kernel_size=3, |
|
|
stride=2, |
|
|
padding=1) |
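
        # The two strided convolutions above are part of this module's
        # parameters but are not applied inside its forward pass.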
|
|
|
|
|
|
|
|
def sample_abs_posemb(self, grid_h: int, grid_w: int): |
|
|
if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w: |
|
|
return self.positional_embedding[None, ...] |
|
|
|
|
|
pos_embed = self.positional_embedding |
|
|
if self.use_cls_token: |
|
|
cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:] |
|
|
|
|
|
pos_embed = (pos_embed.reshape(1, self.posemb_grid_size, |
|
|
self.posemb_grid_size, |
|
|
-1).permute(0, 3, 1, 2).contiguous()) |
|
|
pos_embed = F.interpolate(pos_embed, |
|
|
size=(grid_h, grid_w), |
|
|
mode="bilinear", |
|
|
align_corners=False) |
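        # Bilinear interpolation resizes the learned posemb grid to the
        # requested (grid_h, grid_w); the cls-token embedding, if present,
        # is kept as-is.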
|
|
pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.hidden_size) |
|
|
|
|
|
if self.use_cls_token: |
|
|
pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0) |
|
|
|
|
|
return pos_embed[None, ...] |
|
|
|
|
|
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
pixel_values: Image tensor of shape (B, C, H, W). |
|
|
layer_idx: Negative indices stop after a given block (e.g., -1 uses all blocks). |
|
|
strip_cls_token: If True and cls token is used, remove it from output. |
|
|
""" |
|
|
bsz, _, height, width = pixel_values.shape |
|
|
grid_h, grid_w = height // self.patch_size, width // self.patch_size |
|
|
|
|
|
hidden_state = self.conv1(pixel_values) |
|
|
hidden_state = hidden_state.flatten(2).transpose(1, 2) |
|
|
|
|
|
if self.use_cls_token: |
|
|
cls_token = self.class_embedding.view(1, 1, |
|
|
-1).expand(bsz, -1, -1) |
|
|
hidden_state = torch.cat([cls_token, hidden_state], dim=1) |
|
|
|
|
|
if self.use_abs_posemb: |
|
|
pos_emb = self.sample_abs_posemb(grid_h, grid_w) |
|
|
hidden_state = hidden_state + pos_emb |
|
|
hidden_state = self.ln_pre(hidden_state) |
|
|
hidden_state = self.transformer(hidden_state, grid_hw=(grid_h, grid_w)) |
|
|
|
|
|
        hidden_state = self.ln_post(hidden_state)
|
|
|
|
|
if self.use_cls_token: |
|
|
hidden_state = hidden_state[:, 1:, :] |
|
|
|
|
|
return hidden_state |
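
# Minimal usage sketch. The config values below are illustrative assumptions,
# not the released checkpoint's settings, and they assume the config class
# accepts these fields as keyword arguments:
#
#   config = StepRoboticsVisionEncoderConfig(
#       width=1536, layers=2, heads=16, patch_size=14, image_size=378,
#       num_channels=3, hidden_act="quick_gelu", layer_norm_eps=1e-5)
#   encoder = StepRoboticsVisionEncoder(config)
#   pixel_values = torch.randn(1, 3, 378, 378)
#   features = encoder(pixel_values)   # -> (1, 729, 1536); 729 = (378 // 14) ** 2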
|
|
|