# MLX / MLX_SAM3 / prompt_encoder.py
# Uploaded by Hoodrobot ("Upload 15 files"), commit ced11e2 (verified)
"""
SAM3 Prompt Encoder - Complete MLX Implementation
Encodes different types of user prompts:
- Points (clicks): Positive/negative points with coordinates
- Boxes: Bounding box coordinates (top-left, bottom-right)
- Masks: Dense mask inputs
Outputs:
- Sparse embeddings: Point and box prompt embeddings
- Dense embeddings: Mask prompt embeddings
"""
import mlx.core as mx
import mlx.nn as nn
from mlx.nn import Module
from typing import Optional, Tuple, List
import math
class PositionEmbeddingRandom(Module):
    """
    Positional encoding using random spatial frequencies.

    Fourier-feature style: 2D coordinates are projected through a
    (2, num_pos_feats) Gaussian frequency matrix and passed through sin/cos,
    giving a (num_pos_feats * 2)-dimensional encoding per location.

    Args:
        num_pos_feats: Number of random frequencies (the output has twice
            this many channels).
        scale: Multiplier applied to the standard-normal frequency matrix;
            ``None`` or a non-positive value falls back to 1.0.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None):
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.scale = scale
        # Random frequency matrix of shape (2, num_pos_feats): each COLUMN is
        # one 2D (x, y) frequency vector.
        self.positional_encoding_gaussian_matrix = (
            mx.random.normal(shape=(2, num_pos_feats)) * scale
        )

    def _pe_encoding(self, coords: mx.array) -> mx.array:
        """
        Positionally encode points normalized to [0, 1].

        Args:
            coords: (..., 2) coordinates in [0, 1], ordered (x, y)
        Returns:
            (..., num_pos_feats * 2) encoding laid out as [sin(p), cos(p)]
        """
        # Remap [0, 1] -> [-1, 1] before projecting, matching the reference
        # SAM PositionEmbeddingRandom so loaded frequency matrices see the
        # same input range they were trained with.
        centered = 2.0 * coords - 1.0
        # (..., 2) @ (2, F) -> (..., F), then scale to radians.
        projected = (centered @ self.positional_encoding_gaussian_matrix) * (
            2.0 * math.pi
        )
        return mx.concatenate([mx.sin(projected), mx.cos(projected)], axis=-1)

    def __call__(self, size: Tuple[int, int]) -> mx.array:
        # mlx.nn.Module invokes __call__ directly (there is no implicit
        # forward() dispatch), so route module calls to forward().
        return self.forward(size)

    def forward(self, size: Tuple[int, int]) -> mx.array:
        """
        Generate positional encoding for a 2D grid.

        Args:
            size: (H, W) grid size
        Returns:
            (H, W, C) positional encoding, C = 2 * num_pos_feats
        """
        h, w = size
        # Pixel-center coordinates (i + 0.5) / size, normalized to [0, 1].
        ys = (mx.arange(h, dtype=mx.float32) + 0.5) / h
        xs = (mx.arange(w, dtype=mx.float32) + 0.5) / w
        y_grid = mx.broadcast_to(ys.reshape(h, 1), (h, w))
        x_grid = mx.broadcast_to(xs.reshape(1, w), (h, w))
        # (H, W, 2) in (x, y) order; _pe_encoding handles arbitrary leading dims.
        coords = mx.stack([x_grid, y_grid], axis=-1)
        return self._pe_encoding(coords)

    def forward_with_coords(
        self, coords_input: mx.array, image_size: Tuple[int, int]
    ) -> mx.array:
        """
        Encode arbitrary point coordinates.

        Args:
            coords_input: (B, N, 2) pixel coordinates in (x, y) order
            image_size: (H, W) image dimensions used for normalization
        Returns:
            (B, N, C) positional encodings
        """
        h, w = image_size
        # Divide x by W and y by H in one broadcast op (no in-place mutation).
        norm = mx.array([float(w), float(h)], dtype=mx.float32)
        coords = coords_input.astype(mx.float32) / norm
        return self._pe_encoding(coords)
class PromptEncoder(Module):
    """
    Complete SAM3 Prompt Encoder

    Encodes prompts into embeddings for the mask decoder:
    - Points: positional encoding plus a learned type embedding per label
    - Boxes: positional encoding of both corners plus learned corner embeddings
    - Masks: dense embeddings from a strided-conv downsampling of the mask

    Dense tensors use MLX's channels-last layout (B, H, W, C).

    Args:
        embed_dim: Channel dimension for embeddings
        image_embedding_size: (H, W) size of image embeddings from encoder
        input_image_size: (H, W) of the original input image; used to
            normalize point/box pixel coordinates
        mask_in_chans: Hidden channels for the mask encoder (default 16)
    """

    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int = 16,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        # Positional encoding for points and boxes (output has embed_dim channels).
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
        # Learned type embeddings, indexed by label:
        # 0 = negative point, 1 = positive point,
        # 2 = box top-left corner, 3 = box bottom-right corner.
        self.num_point_embeddings = 4
        self.point_embeddings = [
            nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)
        ]
        # Embedding for padding / "not a point" entries (label -1).
        self.not_a_point_embed = nn.Embedding(1, embed_dim)
        # Mask encoder: two kernel-2 / stride-2 convs give 4x spatial
        # downsampling, so a mask of 4 * image_embedding_size lands on the
        # embedding grid. LayerNorm normalizes the trailing (channel) axis
        # of the channels-last activations.
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            nn.LayerNorm(mask_in_chans // 4),
            nn.GELU(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            nn.LayerNorm(mask_in_chans),
            nn.GELU(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        # Broadcast over the embedding grid when no mask prompt is provided.
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> mx.array:
        """
        Get positional encoding for the image embedding grid.

        Returns:
            (H, W, C) dense positional encoding
        """
        # Call forward() explicitly: mlx.nn.Module has no implicit forward
        # dispatch, so this works regardless of how pe_layer defines __call__.
        return self.pe_layer.forward(self.image_embedding_size)

    def _embed_points(
        self,
        points: mx.array,
        labels: mx.array,
        pad: bool,
    ) -> mx.array:
        """
        Embed point prompts.

        Args:
            points: (B, N, 2) pixel coordinates in (x, y) order
            labels: (B, N) labels: -1 = not a point, 0 = negative,
                1 = positive, 2/3 = box corners; any other value falls back
                to the negative embedding
            pad: Whether to append one "not a point" embedding per batch item
        Returns:
            (B, N, C), or (B, N + 1, C) when pad is True
        """
        points = points + 0.5  # shift to pixel centers
        point_embedding = self.pe_layer.forward_with_coords(
            points, self.input_image_size
        )  # (B, N, C)

        labels = labels.astype(mx.int32)
        # Vectorized type-embedding lookup: stack the per-label weights into
        # a (num_point_embeddings, C) table and gather by label.
        type_table = mx.concatenate(
            [emb.weight for emb in self.point_embeddings], axis=0
        )
        in_range = mx.logical_and(labels >= 0, labels < self.num_point_embeddings)
        type_index = mx.where(in_range, labels, 0)  # out-of-range -> negative
        point_embedding = point_embedding + mx.take(type_table, type_index, axis=0)

        # Label -1 marks "not a point": discard its positional encoding and
        # use the dedicated not-a-point embedding instead.
        is_not_a_point = mx.expand_dims(labels == -1, -1)  # (B, N, 1)
        point_embedding = mx.where(
            is_not_a_point, self.not_a_point_embed.weight, point_embedding
        )

        if pad:
            batch, _, channels = point_embedding.shape
            padding_point = mx.broadcast_to(
                self.not_a_point_embed.weight.reshape(1, 1, -1), (batch, 1, channels)
            )
            point_embedding = mx.concatenate([point_embedding, padding_point], axis=1)
        return point_embedding

    def _embed_boxes(self, boxes: mx.array) -> mx.array:
        """
        Embed box prompts.

        Args:
            boxes: (B, 4) or (B, 1, 4) boxes as [x0, y0, x1, y1] in pixels
        Returns:
            (B, 2, C) corner embeddings [top-left, bottom-right]
        """
        # [x0, y0, x1, y1] -> [[x0, y0], [x1, y1]], shifted to pixel centers.
        coords = (boxes + 0.5).reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(
            coords, self.input_image_size
        )  # (B, 2, C)
        # (2, C): row 0 = top-left type, row 1 = bottom-right type.
        corner_types = mx.concatenate(
            [self.point_embeddings[2].weight, self.point_embeddings[3].weight],
            axis=0,
        )
        return corner_embedding + corner_types

    def _embed_masks(self, masks: mx.array) -> mx.array:
        """
        Embed mask prompts.

        Args:
            masks: (B, 1, H, W) channels-first masks; (B, H, W, 1) and
                (B, H, W) are also accepted
        Returns:
            (B, H / 4, W / 4, C) downsampled mask embeddings
        """
        if masks.ndim == 3:
            masks = mx.expand_dims(masks, -1)
        elif masks.shape[-1] != 1:
            # MLX Conv2d expects channels-last (N, H, W, C) input.
            masks = mx.transpose(masks, (0, 2, 3, 1))
        return self.mask_downscaling(masks)

    @staticmethod
    def _get_batch_size(
        points: Optional[Tuple[mx.array, mx.array]],
        boxes: Optional[mx.array],
        masks: Optional[mx.array],
    ) -> int:
        # Batch size from the first prompt present (points, boxes, masks), else 1.
        if points is not None:
            return points[0].shape[0]
        if boxes is not None:
            return boxes.shape[0]
        if masks is not None:
            return masks.shape[0]
        return 1

    def __call__(
        self,
        points: Optional[Tuple[mx.array, mx.array]] = None,
        boxes: Optional[mx.array] = None,
        masks: Optional[mx.array] = None,
    ) -> Tuple[mx.array, mx.array]:
        # mlx.nn.Module invokes __call__ directly; delegate to forward().
        return self.forward(points=points, boxes=boxes, masks=masks)

    def forward(
        self,
        points: Optional[Tuple[mx.array, mx.array]] = None,
        boxes: Optional[mx.array] = None,
        masks: Optional[mx.array] = None,
    ) -> Tuple[mx.array, mx.array]:
        """
        Encode prompts into sparse and dense embeddings.

        Args:
            points: Optional tuple of (coords, labels)
                - coords: (B, N, 2) point coordinates
                - labels: (B, N) point labels (-1/0/1/2/3, see _embed_points)
            boxes: Optional (B, 4) boxes as [x0, y0, x1, y1]
            masks: Optional (B, 1, H, W) mask prompts
        Returns:
            sparse_embeddings: (B, N_sparse, C) point/box embeddings; a single
                "not a point" embedding per item when no sparse prompt is given
            dense_embeddings: (B, H_emb, W_emb, C) mask embeddings
        """
        # Resolve the batch size once, before any fallback needs it.
        bs = self._get_batch_size(points, boxes, masks)

        sparse_parts = []
        if points is not None:
            coords, labels = points
            # Pad with a not-a-point entry only when no box follows the points.
            sparse_parts.append(self._embed_points(coords, labels, pad=(boxes is None)))
        if boxes is not None:
            sparse_parts.append(self._embed_boxes(boxes))

        if sparse_parts:
            sparse_embeddings = mx.concatenate(sparse_parts, axis=1)
        else:
            sparse_embeddings = mx.broadcast_to(
                self.not_a_point_embed.weight.reshape(1, 1, -1),
                (bs, 1, self.embed_dim),
            )

        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            H, W = self.image_embedding_size
            dense_embeddings = mx.broadcast_to(
                self.no_mask_embed.weight.reshape(1, 1, 1, -1),
                (bs, H, W, self.embed_dim),
            )
        return sparse_embeddings, dense_embeddings
def create_prompt_encoder(
    embed_dim: int = 256,
    image_embedding_size: Tuple[int, int] = (64, 64),
    input_image_size: Tuple[int, int] = (1024, 1024),
    mask_in_chans: int = 16,
) -> PromptEncoder:
    """
    Factory function to create a SAM3 prompt encoder.

    Args:
        embed_dim: Embedding dimension
        image_embedding_size: (H, W) size of the vision encoder output
        input_image_size: (H, W) size of input images
        mask_in_chans: Hidden channels of the mask downscaling encoder
            (default 16, same as PromptEncoder's default)
    Returns:
        PromptEncoder instance
    """
    return PromptEncoder(
        embed_dim=embed_dim,
        image_embedding_size=image_embedding_size,
        input_image_size=input_image_size,
        mask_in_chans=mask_in_chans,
    )