Upload 15 files
Browse files

sam3_mlx/
├── LICENSE (MIT)
├── README.md (Professional docs with badges)
├── CONTRIBUTING.md (Contribution guidelines)
├── pyproject.toml (pip installation)
├── requirements.txt
├── .gitignore
│
├── models/
│   ├── attention.py (RoPE Multi-Head Attention)
│   ├── hiera.py (Hierarchical Vision Encoder)
│   ├── prompt_encoder.py (Point/Box/Mask encoding)
│   ├── mask_decoder.py (Two-way transformer)
│   └── sam3.py (Complete SAM3 model)
│
├── utils/
│   └── weights.py (Weight loading/saving)
│
├── examples/
│   └── click_segment.py (Working demo)
│
└── tests/
    ├── test_models.py (Component validation)
    └── benchmark.py (Performance metrics)
- CONTRIBUTING.md +167 -0
- LICENSE +29 -0
- README.md +51 -0
- __init__.py +25 -0
- attention.py +215 -0
- benchmark.py +148 -0
- click_segment.py +258 -0
- hiera.py +352 -0
- mask_decoder.py +373 -0
- prompt_encoder.py +360 -0
- pyproject.toml +101 -0
- requirements.txt +8 -0
- sam3.py +357 -0
- test_models.py +255 -0
- weights.py +263 -0
@@ -0,0 +1,167 @@
# Contributing to SAM3 MLX

Thank you for considering contributing to SAM3 MLX! This document provides guidelines for contributing to the project.

## Code of Conduct

Be respectful and professional. We're all here to build great software together.

## How to Contribute

### Reporting Bugs

If you find a bug, please open an issue with:
- A clear description of the problem
- Steps to reproduce
- Expected vs. actual behavior
- Your environment (Mac model, macOS version, MLX version)
- Error messages and stack traces

### Suggesting Features

Feature requests are welcome! Please include:
- A clear use case
- Why this feature would be useful
- How it might work

### Pull Requests

1. **Fork the repository**
   ```bash
   git clone https://github.com/yourusername/sam3-mlx.git
   cd sam3-mlx
   ```

2. **Create a branch**
   ```bash
   git checkout -b feature/your-feature-name
   ```

3. **Make your changes**
   - Write clear, documented code
   - Follow the existing code style
   - Add tests for new functionality
   - Update documentation as needed

4. **Test your changes**
   ```bash
   # Run tests
   python tests/test_models.py

   # Run benchmarks
   python tests/benchmark.py

   # Check code style
   black sam3_mlx/
   ruff check sam3_mlx/
   ```

5. **Commit and push**
   ```bash
   git add .
   git commit -m "Add feature: your feature description"
   git push origin feature/your-feature-name
   ```

6. **Open a Pull Request**
   - Describe what you changed and why
   - Link any related issues
   - Wait for review

## Development Setup

```bash
# Clone the repository
git clone https://github.com/yourusername/sam3-mlx.git
cd sam3-mlx

# Install in development mode
pip install -e ".[dev]"

# Run tests
python tests/test_models.py
```

## Code Style

- **Python**: Follow PEP 8
- **Line length**: 100 characters
- **Formatting**: Use `black` for auto-formatting
- **Linting**: Use `ruff` for linting
- **Type hints**: Add type hints for function signatures

Example:
```python
def process_image(image: mx.array, size: int = 1024) -> mx.array:
    """
    Process image for SAM3 input

    Args:
        image: Input image array
        size: Target size

    Returns:
        Processed image
    """
    # Implementation here
    return processed_image
```

## Testing

- Add tests for all new features
- Maintain or improve code coverage
- Test on actual Apple Silicon hardware when possible
- Verify performance benchmarks don't regress
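
For example, a minimal test in the spirit of `tests/test_models.py` might look like this (an illustrative sketch; the actual test file may differ):

```python
import mlx.core as mx
from models.attention import MultiHeadAttentionRoPE

def test_attention_preserves_shape():
    # Attention should map (batch, seq_len, dim) -> (batch, seq_len, dim)
    attn = MultiHeadAttentionRoPE(dim=128, num_heads=8)
    x = mx.random.normal((2, 64, 128))
    y = attn(x)
    mx.eval(y)  # force lazy evaluation before asserting
    assert y.shape == (2, 64, 128)
```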

## Documentation

- Document all public functions and classes
- Update README.md for major changes
- Add examples for new features
- Keep docstrings up to date

## Performance

- Profile new code for performance
- Avoid unnecessary copies with MLX arrays
- Use MLX operations instead of numpy when possible
- Benchmark performance-critical changes
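
For instance, a hypothetical sketch of the last two points (illustrative only, not project code): keep the computation in MLX rather than round-tripping through numpy, and force evaluation with `mx.eval` before timing, since MLX is lazy:

```python
import time
import mlx.core as mx

def normalize(x: mx.array) -> mx.array:
    # Stays in MLX end to end; no numpy conversion or host copies
    return (x - mx.mean(x)) / (mx.std(x) + 1e-6)

x = mx.random.normal((1024, 1024))
start = time.time()
y = normalize(x)
mx.eval(y)  # MLX is lazy: evaluate before reading the clock
print(f"normalize: {(time.time() - start) * 1000:.2f}ms")
```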

## Commit Messages

Write clear commit messages:
- Use present tense ("Add feature", not "Added feature")
- Keep the first line under 72 characters
- Add a detailed description if needed

Good examples:
```
Add RoPE attention implementation

Implements Rotary Position Embeddings for spatial awareness
in the vision transformer.
```

```
Fix memory leak in mask decoder

The transformer was not releasing intermediate tensors,
causing memory to grow with each inference.
```

## Release Process

Maintainers will:
1. Update the version in `pyproject.toml`
2. Update CHANGELOG.md
3. Create a git tag
4. Publish to PyPI

## Questions?

Open an issue or start a discussion!

## License

By contributing, you agree that your contributions will be licensed under the MIT License.
@@ -0,0 +1,29 @@
MIT License

Copyright (c) 2025 SAM3 MLX Contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

---

This project implements Meta's Segment Anything Model 3 (SAM3) architecture.
The original SAM research and model architecture are from Meta AI Research.
Please see: https://segment-anything.com

SAM model weights are subject to Meta's license terms.
@@ -0,0 +1,51 @@
# SAM3 MLX Examples

Example scripts demonstrating how to use SAM3 MLX for segmentation tasks.

## Click-Based Segmentation

Segment objects by clicking on them with positive/negative points.

### Basic Usage

```bash
# Segment with a single positive click
python click_segment.py --image photo.jpg --point 512,384

# Segment with multiple points
python click_segment.py --image photo.jpg --point 512,384 --point 600,400

# Use positive (+) and negative (-) points for refinement
# (use the "=" form for negative points so argparse does not read them as flags)
python click_segment.py --image photo.jpg --point +512,384 --point=-100,100

# Save visualization
python click_segment.py --image photo.jpg --point 512,384 --output result.png

# Get single best mask instead of 3 masks
python click_segment.py --image photo.jpg --point 512,384 --single-mask
```
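
The same prediction is available from Python. A minimal sketch (illustrative; it assumes you run from the repository root so `models` is importable, as `click_segment.py` arranges via `sys.path`, and uses a random array as a stand-in for a real preprocessed image):

```python
import mlx.core as mx
from models.sam3 import SAM3MLX

model = SAM3MLX()
image = mx.random.normal((1, 1024, 1024, 3))   # stand-in for a 1024x1024 NHWC image
point_coords = mx.array([[[512.0, 384.0]]])    # (1, N, 2)
point_labels = mx.array([[1.0]])               # (1, N); 1 = positive click
result = model.predict(
    image=image,
    point_coords=point_coords,
    point_labels=point_labels,
    multimask_output=True,
)
mx.eval(result["masks"])                       # force lazy evaluation
print(result["masks"].shape)                   # (1, num_masks, H, W)
```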

### Requirements

```bash
pip install pillow matplotlib mlx
```

### Performance

On Apple Silicon with MLX:
- Model initialization: ~2-3s
- Single inference: **<200ms** (target performance)
- Multiple masks: 3 predictions per inference

## Box-Based Segmentation

Coming soon: Segment using bounding box prompts.

## Mask-Based Refinement

Coming soon: Refine existing masks with additional mask prompts.

## Batch Processing

Coming soon: Process multiple images efficiently.
@@ -0,0 +1,25 @@
"""
SAM3 MLX Models
Complete implementation of SAM3 components in native MLX
"""

from .attention import MultiHeadAttentionRoPE, WindowedAttention
from .hiera import HieraVisionEncoder, create_hiera_base, create_hiera_large
from .prompt_encoder import PromptEncoder, create_prompt_encoder
from .mask_decoder import MaskDecoder, create_mask_decoder
from .sam3 import SAM3MLX

__all__ = [
    'MultiHeadAttentionRoPE',
    'WindowedAttention',
    'HieraVisionEncoder',
    'create_hiera_base',
    'create_hiera_large',
    'PromptEncoder',
    'create_prompt_encoder',
    'MaskDecoder',
    'create_mask_decoder',
    'SAM3MLX',
]

__version__ = '0.1.0'
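
# Example usage (illustrative; assumes the package directory is on sys.path):
#   from models import SAM3MLX, create_hiera_base
#   model = SAM3MLX()
#   encoder = create_hiera_base()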
@@ -0,0 +1,215 @@
"""
RoPE Multi-Head Attention for SAM3
Implements Rotary Position Embeddings for spatial awareness
"""

import mlx.core as mx
import mlx.nn as nn
from mlx.nn import Module
import math
from typing import Optional


class RoPEEmbedding(Module):
    """Rotary Position Embedding - 2D version for images"""

    def __init__(self, dim: int, max_seq_len: int = 8192):
        super().__init__()
        self.dim = dim

        # Precompute frequency matrix
        inv_freq = 1.0 / (10000 ** (mx.arange(0, dim, 2).astype(mx.float32) / dim))
        self.register_buffer("inv_freq", inv_freq)

    def __call__(self, seq_len: int) -> mx.array:
        """Generate RoPE embeddings for the given sequence length"""
        # Generate position indices
        t = mx.arange(seq_len, dtype=mx.float32)

        # Compute frequencies: outer product of positions and inv_freq
        freqs = mx.outer(t, self.inv_freq)  # (seq_len, dim/2)

        # Create sin and cos embeddings (the two halves are duplicates)
        emb = mx.concatenate([freqs, freqs], axis=-1)  # (seq_len, dim)

        return mx.stack([mx.cos(emb), mx.sin(emb)], axis=0)  # (2, seq_len, dim)

    def register_buffer(self, name: str, tensor: mx.array):
        """Register buffer (MLX doesn't need this, but keeping for compatibility)"""
        setattr(self, name, tensor)


def apply_rotary_pos_emb(q: mx.array, k: mx.array, cos: mx.array, sin: mx.array) -> tuple:
    """
    Apply rotary position embeddings to queries and keys

    Args:
        q: (batch, seq_len, num_heads, head_dim)
        k: (batch, seq_len, num_heads, head_dim)
        cos: (seq_len, head_dim)
        sin: (seq_len, head_dim)

    Returns:
        Rotated q and k
    """
    # Reshape for broadcasting
    cos = cos.reshape(1, -1, 1, cos.shape[-1])  # (1, seq_len, 1, head_dim)
    sin = sin.reshape(1, -1, 1, sin.shape[-1])

    # Split into two halves for rotation
    q_half1, q_half2 = mx.split(q, 2, axis=-1)
    k_half1, k_half2 = mx.split(k, 2, axis=-1)

    # cos/sin duplicate their halves, so slice to head_dim/2 to match the split halves
    half = q_half1.shape[-1]
    cos_h = cos[..., :half]
    sin_h = sin[..., :half]

    # Apply rotation
    q_rotated = mx.concatenate([
        q_half1 * cos_h - q_half2 * sin_h,
        q_half1 * sin_h + q_half2 * cos_h
    ], axis=-1)

    k_rotated = mx.concatenate([
        k_half1 * cos_h - k_half2 * sin_h,
        k_half1 * sin_h + k_half2 * cos_h
    ], axis=-1)

    return q_rotated, k_rotated


class MultiHeadAttentionRoPE(Module):
    """
    Multi-Head Attention with Rotary Position Embeddings

    Key features:
    - RoPE for relative position encoding
    - Flash attention compatible
    - Optimized for MLX/Metal
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 16,
        qkv_bias: bool = True,
        dropout: float = 0.0,
        use_rope: bool = True
    ):
        super().__init__()

        assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.use_rope = use_rope

        # QKV projection
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        # Output projection
        self.proj = nn.Linear(dim, dim)

        # Dropout
        self.attn_dropout = nn.Dropout(dropout) if dropout > 0 else None
        self.proj_dropout = nn.Dropout(dropout) if dropout > 0 else None

        # RoPE
        if use_rope:
            self.rope = RoPEEmbedding(self.head_dim)

    def __call__(self, x: mx.array, attn_mask: Optional[mx.array] = None) -> mx.array:
        """
        Forward pass

        Args:
            x: (batch, seq_len, dim)
            attn_mask: Optional attention mask

        Returns:
            Output: (batch, seq_len, dim)
        """
        B, N, C = x.shape

        # QKV projection and reshape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.transpose(2, 0, 3, 1, 4)  # (3, B, num_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Apply RoPE if enabled
        if self.use_rope:
            rope_emb = self.rope(N)  # (2, N, head_dim)
            cos, sin = rope_emb[0], rope_emb[1]

            # Transpose for apply_rotary: (B, num_heads, N, head_dim) -> (B, N, num_heads, head_dim)
            q = q.transpose(0, 2, 1, 3)
            k = k.transpose(0, 2, 1, 3)

            q, k = apply_rotary_pos_emb(q, k, cos, sin)

            # Transpose back
            q = q.transpose(0, 2, 1, 3)
            k = k.transpose(0, 2, 1, 3)

        # Scaled dot-product attention
        # q, k, v: (B, num_heads, N, head_dim)
        attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale  # (B, num_heads, N, N)

        # Apply attention mask if provided
        if attn_mask is not None:
            attn = attn + attn_mask

        # Softmax
        attn = mx.softmax(attn, axis=-1)

        # Apply dropout
        if self.attn_dropout is not None:
            attn = self.attn_dropout(attn)

        # Apply attention to values
        x = attn @ v  # (B, num_heads, N, head_dim)

        # Reshape and project
        x = x.transpose(0, 2, 1, 3).reshape(B, N, C)
        x = self.proj(x)

        # Apply output dropout
        if self.proj_dropout is not None:
            x = self.proj_dropout(x)

        return x


class WindowedAttention(MultiHeadAttentionRoPE):
    """
    Windowed Multi-Head Attention for local processing
    Used in certain Hiera blocks for efficiency
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 16,
        window_size: int = 14,
        **kwargs
    ):
        super().__init__(dim, num_heads, **kwargs)
        self.window_size = window_size

    def create_window_mask(self, seq_len: int) -> mx.array:
        """Create attention mask that only allows attention within window_size"""
        # Vectorized: position j is visible from i iff |i - j| <= window_size // 2
        idx = mx.arange(seq_len)
        dist = mx.abs(idx[:, None] - idx[None, :])
        mask = mx.where(dist <= self.window_size // 2, 0.0, float('-inf'))

        return mask.reshape(1, 1, seq_len, seq_len)

    def __call__(self, x: mx.array) -> mx.array:
        """Forward with windowed attention"""
        B, N, C = x.shape

        # Create window mask
        window_mask = self.create_window_mask(N)

        return super().__call__(x, attn_mask=window_mask)
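
# --- Example usage (illustrative smoke test; random weights, shapes only) ---
if __name__ == "__main__":
    x = mx.random.normal((1, 196, 256))  # (batch, seq_len, dim)

    attn = MultiHeadAttentionRoPE(dim=256, num_heads=8)
    y = attn(x)
    mx.eval(y)
    print("MultiHeadAttentionRoPE:", y.shape)  # (1, 196, 256)

    w_attn = WindowedAttention(dim=256, num_heads=8, window_size=14)
    y = w_attn(x)
    mx.eval(y)
    print("WindowedAttention:", y.shape)  # (1, 196, 256)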
@@ -0,0 +1,148 @@
#!/usr/bin/env python3
"""
SAM3 MLX Benchmarks

Measures performance on Apple Silicon to validate the <200ms target
"""

import time
import mlx.core as mx
import numpy as np
import sys
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from models.sam3 import SAM3MLX


def benchmark_component(name: str, func, *args, warmup=3, iterations=10, **kwargs):
    """Benchmark a component with warmup"""
    print(f"\n{'='*60}")
    print(f"Benchmarking: {name}")
    print(f"{'='*60}")

    # Warmup
    print(f"Warming up ({warmup} iterations)...")
    for _ in range(warmup):
        result = func(*args, **kwargs)
        if isinstance(result, dict):
            for v in result.values():
                if isinstance(v, mx.array):
                    mx.eval(v)
        elif isinstance(result, mx.array):
            mx.eval(result)

    # Benchmark
    print(f"Running benchmark ({iterations} iterations)...")
    times = []

    for i in range(iterations):
        start = time.time()
        result = func(*args, **kwargs)

        # Force evaluation
        if isinstance(result, dict):
            for v in result.values():
                if isinstance(v, mx.array):
                    mx.eval(v)
        elif isinstance(result, mx.array):
            mx.eval(result)

        elapsed = (time.time() - start) * 1000  # Convert to ms
        times.append(elapsed)
        print(f"  Iteration {i+1}: {elapsed:.2f}ms")

    # Statistics
    times = np.array(times)
    print("\n📊 Results:")
    print(f"  Mean:   {times.mean():.2f}ms")
    print(f"  Median: {np.median(times):.2f}ms")
    print(f"  Min:    {times.min():.2f}ms")
    print(f"  Max:    {times.max():.2f}ms")
    print(f"  Std:    {times.std():.2f}ms")

    return times.mean()


def main():
    print("🚀 SAM3 MLX Performance Benchmarks")
    print("=" * 60)
    print(f"MLX version: {mx.__version__}")
    print("Device: Apple Silicon (Metal)")
    print("=" * 60)

    # Initialize model
    print("\n🏗️ Initializing SAM3 MLX...")
    model = SAM3MLX()

    # Prepare inputs
    print("\n📦 Preparing test inputs...")
    image = mx.random.normal((1, 1024, 1024, 3))
    point_coords = mx.array([[[512, 384]]]).astype(mx.float32)
    point_labels = mx.array([[1]]).astype(mx.float32)

    # Benchmark components
    results = {}

    # 1. Vision Encoder
    results['vision_encoder'] = benchmark_component(
        "Vision Encoder (Hiera)",
        model.encode_image,
        image,
        warmup=3,
        iterations=10,
    )

    # 2. Prompt Encoder
    results['prompt_encoder'] = benchmark_component(
        "Prompt Encoder",
        model.prompt_encoder,
        (point_coords, point_labels),
        None,
        None,
        warmup=3,
        iterations=20,
    )

    # 3. Full Pipeline
    results['full_pipeline'] = benchmark_component(
        "Full Pipeline (encode + decode)",
        model.predict,
        image,
        point_coords,
        point_labels,
        warmup=3,
        iterations=10,
    )

    # Summary
    print(f"\n{'='*60}")
    print("PERFORMANCE SUMMARY")
    print(f"{'='*60}")

    for component, avg_time in results.items():
        status = "✅" if avg_time < 1000 else "⚠️"
        print(f"{status} {component:30s} {avg_time:8.2f}ms")

    print(f"\n{'='*60}")
    print("TARGET METRICS")
    print(f"{'='*60}")

    vision_target = 500  # ms
    full_target = 200  # ms (after optimization)

    vision_status = "✅ PASS" if results['vision_encoder'] < vision_target else "❌ FAIL"
    full_status = "🎯 TARGET" if results['full_pipeline'] < full_target else "⚠️ NEEDS OPTIMIZATION"

    print(f"Vision Encoding: {vision_status} (target: <{vision_target}ms)")
    print(f"Full Pipeline:   {full_status} (target: <{full_target}ms)")

    print(f"\n{'='*60}")
    print("Benchmark complete!")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,258 @@
#!/usr/bin/env python3
"""
SAM3 MLX Click Segmentation Example

Demonstrates how to:
1. Load the SAM3 MLX model
2. Process an image
3. Segment objects with point clicks
4. Visualize results

Usage:
    python click_segment.py --image path/to/image.jpg --point 100,200
"""

import argparse
import time
from pathlib import Path
from typing import Tuple, Optional

import numpy as np
import mlx.core as mx

try:
    from PIL import Image
    import matplotlib.pyplot as plt
except ImportError:
    print("❌ Please install PIL and matplotlib:")
    print("   pip install pillow matplotlib")
    exit(1)

# Add parent directory to path
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))

from models.sam3 import SAM3MLX
from utils.weights import load_weights


def load_image(image_path: str, target_size: int = 1024) -> Tuple[mx.array, np.ndarray]:
    """
    Load and preprocess an image for SAM3

    Args:
        image_path: Path to image file
        target_size: Target image size (SAM3 uses 1024x1024)

    Returns:
        Tuple of (preprocessed MLX array, original numpy array)
    """
    # Load image
    img = Image.open(image_path).convert("RGB")
    original = np.array(img)

    # Resize to target size
    img_resized = img.resize((target_size, target_size), Image.BILINEAR)
    img_np = np.array(img_resized).astype(np.float32) / 255.0

    # Convert to MLX array in NHWC format
    img_mlx = mx.array(img_np).reshape(1, target_size, target_size, 3)

    return img_mlx, original


def visualize_prediction(
    image: np.ndarray,
    masks: mx.array,
    point_coords: mx.array,
    point_labels: mx.array,
    iou_scores: mx.array,
    save_path: Optional[str] = None,
):
    """
    Visualize segmentation results

    Args:
        image: Original image (H, W, 3)
        masks: Predicted masks (1, num_masks, H, W)
        point_coords: Input point coordinates (1, N, 2)
        point_labels: Input point labels (1, N)
        iou_scores: IoU quality scores (1, num_masks)
        save_path: Optional path to save visualization
    """
    # Convert MLX to numpy
    masks_np = np.array(masks[0])  # (num_masks, H, W)
    point_coords_np = np.array(point_coords[0])  # (N, 2)
    point_labels_np = np.array(point_labels[0])  # (N,)
    iou_scores_np = np.array(iou_scores[0])  # (num_masks,)

    num_masks = masks_np.shape[0]

    # Create figure
    fig, axes = plt.subplots(1, num_masks + 1, figsize=(5 * (num_masks + 1), 5))
    if num_masks == 1:
        axes = [axes[0], axes[1]]

    # Show original image with points
    axes[0].imshow(image)
    axes[0].set_title("Input Image with Points")

    # Plot positive points (green) and negative points (red)
    for coord, label in zip(point_coords_np, point_labels_np):
        color = 'g' if label == 1 else 'r'
        marker = 'o' if label == 1 else 'x'
        axes[0].scatter(coord[0], coord[1], c=color, marker=marker, s=200, linewidths=3)

    axes[0].axis('off')

    # Show each predicted mask
    for i in range(num_masks):
        # Resize mask to original image size
        mask = masks_np[i]
        H, W = image.shape[:2]
        from PIL import Image as PILImage
        mask_resized = PILImage.fromarray((mask * 255).astype(np.uint8))
        mask_resized = mask_resized.resize((W, H), PILImage.BILINEAR)
        mask_resized = np.array(mask_resized) / 255.0

        # Overlay mask on image
        overlay = image.copy()
        mask_3ch = np.stack([mask_resized] * 3, axis=-1)
        overlay = (overlay * (1 - mask_3ch * 0.5) + np.array([0, 255, 0]) * mask_3ch * 0.5).astype(np.uint8)

        axes[i + 1].imshow(overlay)
        axes[i + 1].set_title(f"Mask {i+1} (IoU: {iou_scores_np[i]:.3f})")
        axes[i + 1].axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=150)
        print(f"💾 Saved visualization to {save_path}")

    plt.show()


def main():
    parser = argparse.ArgumentParser(description="SAM3 MLX Click Segmentation Example")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument(
        "--point",
        type=str,
        action="append",
        help="Click point as 'x,y' (can specify multiple). Use +x,y for positive, -x,y for "
             "negative (pass negative points as --point=-x,y so argparse accepts them)",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="./checkpoints/sam3_mlx",
        help="Path to SAM3 MLX checkpoint directory",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Path to save output visualization",
    )
    parser.add_argument(
        "--single-mask",
        action="store_true",
        help="Output a single mask instead of 3 masks",
    )
    args = parser.parse_args()

    print("🚀 SAM3 MLX Click Segmentation Example")
    print("=" * 60)

    # Parse points
    if not args.point:
        print("❌ Please specify at least one point with --point x,y")
        return

    point_coords_list = []
    point_labels_list = []

    for point_str in args.point:
        # Check for label prefix
        if point_str.startswith('+'):
            label = 1  # Positive
            point_str = point_str[1:]
        elif point_str.startswith('-'):
            label = 0  # Negative
            point_str = point_str[1:]
        else:
            label = 1  # Default to positive

        x, y = map(float, point_str.split(','))
        point_coords_list.append([x, y])
        point_labels_list.append(label)

    point_coords = mx.array(point_coords_list).reshape(1, -1, 2)
    point_labels = mx.array(point_labels_list).reshape(1, -1)

    print(f"📍 Input points: {len(point_coords_list)}")
    for i, (coord, label) in enumerate(zip(point_coords_list, point_labels_list)):
        label_str = "positive" if label == 1 else "negative"
        print(f"  Point {i+1}: ({coord[0]:.0f}, {coord[1]:.0f}) [{label_str}]")

    # Load image
    print(f"\n📸 Loading image: {args.image}")
    image_mlx, image_original = load_image(args.image)
    print(f"  Image size: {image_original.shape[1]}x{image_original.shape[0]}")

    # Initialize model
    print("\n🏗️ Initializing SAM3 MLX model...")
    model = SAM3MLX()

    # Load weights if available
    checkpoint_dir = Path(args.checkpoint)
    weights_path = checkpoint_dir / "sam3_mlx_weights.npz"

    if weights_path.exists():
        print(f"\n📥 Loading weights from {checkpoint_dir}")
        model = load_weights(model, str(weights_path), strict=False, verbose=True)
    else:
        print(f"\n⚠️ Weights not found at {weights_path}")
        print("  Using randomly initialized model (for testing architecture only)")

    # Run inference
    print("\n🎯 Running segmentation...")
    start_time = time.time()

    result = model.predict(
        image=image_mlx,
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=not args.single_mask,
    )

    # Ensure computation is complete
    mx.eval(result["masks"])

    inference_time = (time.time() - start_time) * 1000
    print(f"✅ Inference completed in {inference_time:.1f}ms")

    # Print results
    masks = result["masks"]
    iou_predictions = result["iou_predictions"]

    print("\n📊 Results:")
    print(f"  Number of masks: {masks.shape[1]}")
    print(f"  Mask resolution: {masks.shape[2]}x{masks.shape[3]}")
    print(f"  IoU scores: {np.array(iou_predictions[0])}")

    # Visualize
    print("\n🎨 Visualizing results...")
    visualize_prediction(
        image_original,
        masks,
        point_coords,
        point_labels,
        iou_predictions,
        save_path=args.output,
    )

    print("\n✅ Done!")


if __name__ == "__main__":
    main()
@@ -0,0 +1,352 @@
"""
Hiera (Hierarchical Vision Transformer) - Complete MLX Implementation

This is the vision backbone used in SAM3, featuring:
- Multi-scale hierarchical processing
- Stage-wise spatial pooling
- RoPE attention at each scale
- Efficient computation via MLX/Metal
"""

import mlx.core as mx
import mlx.nn as nn
from mlx.nn import Module
from typing import List, Optional, Tuple
from .attention import MultiHeadAttentionRoPE, WindowedAttention


class MLP(Module):
    """
    Multi-Layer Perceptron with GELU activation
    Standard FFN block in transformers
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def __call__(self, x: mx.array) -> mx.array:
        x = self.fc1(x)
        x = self.act(x)
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.fc2(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return x


class HieraBlock(Module):
    """
    Single Hiera transformer block

    Features:
    - Pre-LayerNorm architecture
    - RoPE Multi-Head Attention
    - MLP with GELU
    - Residual connections
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        dropout: float = 0.0,
        use_windowed_attn: bool = False,
        window_size: int = 14,
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(dim)

        # Choose attention type
        if use_windowed_attn:
            self.attn = WindowedAttention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                dropout=dropout,
                window_size=window_size
            )
        else:
            self.attn = MultiHeadAttentionRoPE(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                dropout=dropout
            )

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * mlp_ratio), dropout=dropout)

    def __call__(self, x: mx.array) -> mx.array:
        # Attention with pre-norm and residual
        x = x + self.attn(self.norm1(x))

        # MLP with pre-norm and residual
        x = x + self.mlp(self.norm2(x))

        return x


class PatchEmbed(Module):
    """
    Image to Patch Embedding using Conv2d

    Converts (B, H, W, C) image to (B, num_patches, embed_dim) patches
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 14,
        in_chans: int = 3,
        embed_dim: int = 1024
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size ** 2

        # Convolution for patch embedding
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def __call__(self, x: mx.array) -> mx.array:
        """
        Args:
            x: (B, H, W, C) in NHWC format (MLX convention)

        Returns:
            (B, num_patches, embed_dim)
        """
        B, H, W, C = x.shape

        # Apply convolution
        x = self.proj(x)  # (B, H', W', embed_dim) where H' = W' = grid_size

        # Flatten spatial dimensions
        B, H_p, W_p, C_emb = x.shape
        x = x.reshape(B, H_p * W_p, C_emb)  # (B, num_patches, embed_dim)

        return x


class DownsampleBlock(Module):
    """
    Spatial downsampling block for hierarchical processing

    Reduces spatial resolution by 2x while increasing channels
    Uses depthwise-separable convolution for efficiency
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()

        # Depthwise convolution (2x2 pooling with stride 2)
        self.dw_conv = nn.Conv2d(in_dim, in_dim, kernel_size=2, stride=2, groups=in_dim)

        # Pointwise convolution (1x1 to change channels)
        self.pw_conv = nn.Conv2d(in_dim, out_dim, kernel_size=1)

        self.norm = nn.LayerNorm(out_dim)

    def __call__(self, x: mx.array, h: int, w: int) -> Tuple[mx.array, int, int]:
        """
        Args:
            x: (B, N, C) where N = h*w
            h, w: Spatial dimensions

        Returns:
            (B, N//4, C'), h//2, w//2
        """
        B, N, C = x.shape

        # Reshape to spatial format: (B, N, C) -> (B, h, w, C)
        x = x.reshape(B, h, w, C)

        # Apply convolutions
        x = self.dw_conv(x)
        x = self.pw_conv(x)

        # Flatten back: (B, h//2, w//2, out_dim) -> (B, N//4, out_dim)
        B, h_new, w_new, C_new = x.shape
        x = x.reshape(B, h_new * w_new, C_new)

        # Normalize
        x = self.norm(x)

        return x, h_new, w_new


class HieraStage(Module):
    """
    Single stage of Hiera with multiple blocks

    Each stage processes at a specific spatial scale
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        use_windowed_attn: bool = False,
        window_size: int = 14,
    ):
        super().__init__()

        self.blocks = [
            HieraBlock(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                use_windowed_attn=use_windowed_attn and (i % 2 == 0),  # Alternate global/local
                window_size=window_size
            )
            for i in range(depth)
        ]

    def __call__(self, x: mx.array) -> mx.array:
        for block in self.blocks:
            x = block(x)
        return x


class HieraVisionEncoder(Module):
    """
    Complete Hiera Vision Encoder

    Multi-scale hierarchical vision transformer with:
    - 4 stages with increasing channel dimensions
    - Spatial downsampling between stages
    - RoPE attention at all scales
    - Both global and windowed attention

    Args:
        img_size: Input image size
        patch_size: Initial patch size
        in_chans: Input channels (3 for RGB)
        embed_dims: Channel dimensions for each stage
        depths: Number of blocks per stage
        num_heads: Attention heads per stage
        mlp_ratio: MLP hidden dim ratio
        use_windowed_attn: Use windowed attention in stages
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 14,
        in_chans: int = 3,
        embed_dims: List[int] = [256, 512, 1024, 1024],  # Progressive channel increase
        depths: List[int] = [2, 8, 16, 6],  # Blocks per stage
        num_heads: List[int] = [4, 8, 16, 16],
        mlp_ratio: float = 4.0,
        use_windowed_attn: bool = True,
        window_size: int = 14,
    ):
        super().__init__()

        assert len(embed_dims) == len(depths) == len(num_heads), \
            "embed_dims, depths, and num_heads must have the same length"

        self.num_stages = len(embed_dims)
        self.patch_size = patch_size

        # Patch embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dims[0]
        )

        # Initial spatial dimensions
        self.init_h = self.init_w = img_size // patch_size

        # Pre-norm before stages
        self.norm_pre = nn.LayerNorm(embed_dims[0])

        # Build stages
        self.stages = []
        self.downsample_layers = []

        for i in range(self.num_stages):
            # Create stage
            stage = HieraStage(
                dim=embed_dims[i],
                depth=depths[i],
                num_heads=num_heads[i],
                mlp_ratio=mlp_ratio,
                use_windowed_attn=use_windowed_attn,
                window_size=window_size
            )
            self.stages.append(stage)

            # Create downsampling layer (except for the last stage)
            if i < self.num_stages - 1:
                downsample = DownsampleBlock(embed_dims[i], embed_dims[i + 1])
                self.downsample_layers.append(downsample)

        # Final norm
        self.norm = nn.LayerNorm(embed_dims[-1])

    def __call__(self, x: mx.array) -> mx.array:
        """
        Args:
            x: (B, H, W, C) image in NHWC format

        Returns:
            (B, num_patches_final, embed_dim_final) features
        """
        # Patch embedding
        x = self.patch_embed(x)  # (B, num_patches, embed_dims[0])

        # Pre-norm
        x = self.norm_pre(x)

        # Track spatial dimensions
        h, w = self.init_h, self.init_w

        # Process through stages
        for i, stage in enumerate(self.stages):
            # Apply stage
            x = stage(x)

            # Downsample (except last stage)
            if i < len(self.downsample_layers):
                x, h, w = self.downsample_layers[i](x, h, w)

        # Final norm
        x = self.norm(x)

        return x


def create_hiera_base() -> HieraVisionEncoder:
    """Create Hiera-Base configuration (SAM3 default)"""
    return HieraVisionEncoder(
        img_size=1024,
        patch_size=14,
        embed_dims=[256, 512, 1024, 1024],
        depths=[2, 8, 16, 6],
        num_heads=[4, 8, 16, 16]
    )


def create_hiera_large() -> HieraVisionEncoder:
    """Create Hiera-Large configuration"""
    return HieraVisionEncoder(
        img_size=1024,
        patch_size=14,
        embed_dims=[384, 768, 1536, 1536],
        depths=[2, 8, 20, 8],
        num_heads=[6, 12, 24, 24]
    )
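
# --- Example usage (illustrative smoke test; deliberately tiny config, not a SAM3 preset) ---
if __name__ == "__main__":
    encoder = HieraVisionEncoder(
        img_size=224, patch_size=14,
        embed_dims=[64, 128], depths=[1, 1], num_heads=[2, 4],
    )
    img = mx.random.normal((1, 224, 224, 3))
    feats = encoder(img)
    mx.eval(feats)
    print(feats.shape)  # (1, 64, 128): 16x16 patches, downsampled once to 8x8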
@@ -0,0 +1,373 @@
"""
SAM3 Mask Decoder - Complete MLX Implementation

Predicts high-resolution segmentation masks from:
- Image embeddings (from the Hiera vision encoder)
- Prompt embeddings (from the prompt encoder)

Architecture:
1. Transformer decoder with cross-attention to image features
2. Dynamic mask prediction head
3. IoU quality prediction
4. Multi-mask output (3 masks + confidence scores)
"""

import mlx.core as mx
import mlx.nn as nn
from mlx.nn import Module
from typing import Tuple, List


class MLPBlock(Module):
    """
    Simple MLP block with one hidden layer
    Used in the transformer and prediction heads
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        activation=nn.GELU
    ):
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = activation()

    def __call__(self, x: mx.array) -> mx.array:
        return self.lin2(self.act(self.lin1(x)))


class TwoWayAttentionBlock(Module):
    """
    Two-way cross-attention transformer block

    Performs:
    1. Self-attention on queries (prompts)
    2. Cross-attention from queries to keys (image features)
    3. MLP on queries
    4. Cross-attention from keys to queries
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int = 8,
        mlp_dim: int = 2048,
        activation=nn.GELU,
        skip_first_layer_pe: bool = False,
    ):
        super().__init__()
        self.self_attn = nn.MultiHeadAttention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = nn.MultiHeadAttention(
            embedding_dim, num_heads // 2
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = nn.MultiHeadAttention(
            embedding_dim, num_heads // 2
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def __call__(
        self,
        queries: mx.array,
        keys: mx.array,
        query_pe: mx.array,
        key_pe: mx.array,
    ) -> Tuple[mx.array, mx.array]:
        """
        Args:
            queries: (B, N_q, C) prompt tokens
            keys: (B, N_k, C) image tokens
            query_pe: (B, N_q, C) positional encoding for queries
            key_pe: (B, N_k, C) positional encoding for keys

        Returns:
            Updated queries and keys
        """
        # Self-attention on queries
        if self.skip_first_layer_pe:
            queries = self.self_attn(queries, queries, queries)
        else:
            q = queries + query_pe
            queries = self.self_attn(q, q, queries)
        queries = self.norm1(queries)

        # Cross-attention: queries -> image
        q = queries + query_pe
        k = keys + key_pe
        queries = queries + self.cross_attn_token_to_image(q, k, keys)
        queries = self.norm2(queries)

        # MLP
        queries = queries + self.mlp(queries)
        queries = self.norm3(queries)

        # Cross-attention: image -> queries
        q = queries + query_pe
        k = keys + key_pe
        keys = keys + self.cross_attn_image_to_token(k, q, queries)
        keys = self.norm4(keys)

        return queries, keys


class TwoWayTransformer(Module):
    """
    Two-way transformer decoder

    Processes sparse prompts and dense image features
    to produce mask predictions
    """

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
    ):
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim

        # Stack of two-way attention blocks
        self.layers = [
            TwoWayAttentionBlock(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                mlp_dim=mlp_dim,
                skip_first_layer_pe=(i == 0),
            )
            for i in range(depth)
        ]

        self.final_attn_token_to_image = nn.MultiHeadAttention(
            embedding_dim, num_heads
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def __call__(
        self,
        image_embedding: mx.array,
        image_pe: mx.array,
        point_embedding: mx.array,
    ) -> Tuple[mx.array, mx.array]:
        """
        Args:
            image_embedding: (B, H*W, C) image features
            image_pe: (B, H*W, C) positional encoding for image
            point_embedding: (B, N, C) prompt embeddings

        Returns:
            Updated tokens and image features
        """
        # Prepare queries (prompts) and keys (image)
        queries = point_embedding
        keys = image_embedding

        # Pass through transformer layers
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Final attention from prompts to image
|
| 188 |
+
q = queries + point_embedding
|
| 189 |
+
k = keys + image_pe
|
| 190 |
+
queries = queries + self.final_attn_token_to_image(q, k, keys)
|
| 191 |
+
queries = self.norm_final_attn(queries)
|
| 192 |
+
|
| 193 |
+
return queries, keys
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class MaskDecoder(Module):
|
| 197 |
+
"""
|
| 198 |
+
Complete SAM3 Mask Decoder
|
| 199 |
+
|
| 200 |
+
Predicts segmentation masks from image and prompt embeddings.
|
| 201 |
+
Outputs multiple masks with quality scores.
|
| 202 |
+
|
| 203 |
+
Args:
|
| 204 |
+
transformer_dim: Channel dimension of transformer
|
| 205 |
+
transformer: Two-way transformer for mask prediction
|
| 206 |
+
num_multimask_outputs: Number of masks to predict (default 3)
|
| 207 |
+
iou_head_depth: Depth of IoU prediction MLP
|
| 208 |
+
iou_head_hidden_dim: Hidden dim for IoU MLP
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
def __init__(
|
| 212 |
+
self,
|
| 213 |
+
transformer_dim: int = 256,
|
| 214 |
+
transformer_depth: int = 2,
|
| 215 |
+
transformer_num_heads: int = 8,
|
| 216 |
+
transformer_mlp_dim: int = 2048,
|
| 217 |
+
num_multimask_outputs: int = 3,
|
| 218 |
+
iou_head_depth: int = 3,
|
| 219 |
+
iou_head_hidden_dim: int = 256,
|
| 220 |
+
):
|
| 221 |
+
super().__init__()
|
| 222 |
+
self.transformer_dim = transformer_dim
|
| 223 |
+
self.num_multimask_outputs = num_multimask_outputs
|
| 224 |
+
|
| 225 |
+
# Two-way transformer
|
| 226 |
+
self.transformer = TwoWayTransformer(
|
| 227 |
+
depth=transformer_depth,
|
| 228 |
+
embedding_dim=transformer_dim,
|
| 229 |
+
num_heads=transformer_num_heads,
|
| 230 |
+
mlp_dim=transformer_mlp_dim,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
# IoU prediction head
|
| 234 |
+
self.iou_token = nn.Embedding(1, transformer_dim)
|
| 235 |
+
|
| 236 |
+
# Mask tokens for multi-mask prediction
|
| 237 |
+
self.num_mask_tokens = num_multimask_outputs + 1 # +1 for single mask
|
| 238 |
+
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
|
| 239 |
+
|
| 240 |
+
# Output upscaling layers
|
| 241 |
+
# Upsample from 64x64 -> 256x256 (4x upsampling)
|
| 242 |
+
self.output_upscaling = nn.Sequential(
|
| 243 |
+
nn.ConvTranspose2d(
|
| 244 |
+
transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
|
| 245 |
+
),
|
| 246 |
+
nn.LayerNorm(transformer_dim // 4),
|
| 247 |
+
nn.GELU(),
|
| 248 |
+
nn.ConvTranspose2d(
|
| 249 |
+
transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
|
| 250 |
+
),
|
| 251 |
+
nn.GELU(),
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
# Mask prediction heads (one per mask)
|
| 255 |
+
self.output_hypernetworks_mlps = [
|
| 256 |
+
MLPBlock(transformer_dim, transformer_dim // 8, nn.GELU)
|
| 257 |
+
for _ in range(self.num_mask_tokens)
|
| 258 |
+
]
|
| 259 |
+
|
| 260 |
+
# IoU prediction head
|
| 261 |
+
self.iou_prediction_head = MLPBlock(
|
| 262 |
+
transformer_dim, iou_head_hidden_dim, nn.GELU
|
| 263 |
+
)
|
| 264 |
+
self.iou_prediction_linear = nn.Linear(iou_head_hidden_dim, self.num_mask_tokens)
|
| 265 |
+
|
| 266 |
+
def forward(
|
| 267 |
+
self,
|
| 268 |
+
image_embeddings: mx.array,
|
| 269 |
+
image_pe: mx.array,
|
| 270 |
+
sparse_prompt_embeddings: mx.array,
|
| 271 |
+
dense_prompt_embeddings: mx.array,
|
| 272 |
+
multimask_output: bool = True,
|
| 273 |
+
) -> Tuple[mx.array, mx.array]:
|
| 274 |
+
"""
|
| 275 |
+
Predict masks from image and prompt embeddings
|
| 276 |
+
|
| 277 |
+
Args:
|
| 278 |
+
image_embeddings: (B, H, W, C) from vision encoder
|
| 279 |
+
image_pe: (B, H, W, C) positional encoding for image
|
| 280 |
+
sparse_prompt_embeddings: (B, N, C) point/box embeddings
|
| 281 |
+
dense_prompt_embeddings: (B, H, W, C) mask embeddings
|
| 282 |
+
multimask_output: Return 3 masks or 1 mask
|
| 283 |
+
|
| 284 |
+
Returns:
|
| 285 |
+
masks: (B, num_masks, H, W) predicted masks
|
| 286 |
+
iou_pred: (B, num_masks) quality scores
|
| 287 |
+
"""
|
| 288 |
+
B, H, W, C = image_embeddings.shape
|
| 289 |
+
|
| 290 |
+
# Flatten image embeddings and PE
|
| 291 |
+
image_embeddings_flat = image_embeddings.reshape(B, H * W, C)
|
| 292 |
+
image_pe_flat = image_pe.reshape(B, H * W, C)
|
| 293 |
+
|
| 294 |
+
# Concatenate output tokens
|
| 295 |
+
iou_token_out = self.iou_token.weight.reshape(1, 1, -1).broadcast_to(
|
| 296 |
+
(B, 1, self.transformer_dim)
|
| 297 |
+
)
|
| 298 |
+
mask_tokens_out = self.mask_tokens.weight.reshape(1, -1, self.transformer_dim).broadcast_to(
|
| 299 |
+
(B, self.num_mask_tokens, self.transformer_dim)
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
# Combine all prompt tokens: [IoU token, mask tokens, sparse prompts]
|
| 303 |
+
tokens = mx.concatenate(
|
| 304 |
+
[iou_token_out, mask_tokens_out, sparse_prompt_embeddings], axis=1
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
# Add dense prompt embeddings to image
|
| 308 |
+
src = image_embeddings_flat + dense_prompt_embeddings.reshape(B, H * W, C)
|
| 309 |
+
|
| 310 |
+
# Run through transformer
|
| 311 |
+
hs, src = self.transformer(src, image_pe_flat, tokens)
|
| 312 |
+
|
| 313 |
+
# Extract tokens
|
| 314 |
+
iou_token_out = hs[:, 0:1, :]
|
| 315 |
+
mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]
|
| 316 |
+
|
| 317 |
+
# Upscale image embeddings
|
| 318 |
+
# Reshape to (B, H, W, C) for upsampling
|
| 319 |
+
src = src.reshape(B, H, W, C)
|
| 320 |
+
upscaled_embedding = self.output_upscaling(src) # (B, H*4, W*4, C//8)
|
| 321 |
+
|
| 322 |
+
B_up, H_up, W_up, C_up = upscaled_embedding.shape
|
| 323 |
+
|
| 324 |
+
# Predict masks using hypernetworks
|
| 325 |
+
masks = []
|
| 326 |
+
for i in range(self.num_mask_tokens):
|
| 327 |
+
# Get mask token features
|
| 328 |
+
mask_features = self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
|
| 329 |
+
# (B, C//8)
|
| 330 |
+
|
| 331 |
+
# Expand to spatial dimensions and compute dot product
|
| 332 |
+
mask_features = mask_features.reshape(B, 1, 1, C_up)
|
| 333 |
+
mask = (upscaled_embedding * mask_features).sum(axis=-1) # (B, H_up, W_up)
|
| 334 |
+
masks.append(mask)
|
| 335 |
+
|
| 336 |
+
masks = mx.stack(masks, axis=1) # (B, num_masks, H_up, W_up)
|
| 337 |
+
|
| 338 |
+
# Predict IoU scores
|
| 339 |
+
iou_pred = self.iou_prediction_head(iou_token_out)
|
| 340 |
+
iou_pred = self.iou_prediction_linear(iou_pred).squeeze(1) # (B, num_masks)
|
| 341 |
+
|
| 342 |
+
# Select correct masks
|
| 343 |
+
if multimask_output:
|
| 344 |
+
# Return 3 multi-masks
|
| 345 |
+
mask_slice = slice(1, None)
|
| 346 |
+
else:
|
| 347 |
+
# Return single mask
|
| 348 |
+
mask_slice = slice(0, 1)
|
| 349 |
+
|
| 350 |
+
masks = masks[:, mask_slice, :, :]
|
| 351 |
+
iou_pred = iou_pred[:, mask_slice]
|
| 352 |
+
|
| 353 |
+
return masks, iou_pred
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def create_mask_decoder(
|
| 357 |
+
transformer_dim: int = 256,
|
| 358 |
+
num_multimask_outputs: int = 3,
|
| 359 |
+
) -> MaskDecoder:
|
| 360 |
+
"""
|
| 361 |
+
Factory function to create SAM3 mask decoder
|
| 362 |
+
|
| 363 |
+
Args:
|
| 364 |
+
transformer_dim: Feature dimension
|
| 365 |
+
num_multimask_outputs: Number of masks to output
|
| 366 |
+
|
| 367 |
+
Returns:
|
| 368 |
+
MaskDecoder instance
|
| 369 |
+
"""
|
| 370 |
+
return MaskDecoder(
|
| 371 |
+
transformer_dim=transformer_dim,
|
| 372 |
+
num_multimask_outputs=num_multimask_outputs,
|
| 373 |
+
)
|
|
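
For orientation, a minimal usage sketch of the decoder above, run with random tensors and random weights, so it exercises shapes and wiring only. The `models.mask_decoder` import path assumes the repository root is on `sys.path`, as in the test suite below.

import mlx.core as mx
from models.mask_decoder import create_mask_decoder

decoder = create_mask_decoder(transformer_dim=256)
B, H, W, C = 1, 64, 64, 256
masks, iou_pred = decoder(
    image_embeddings=mx.random.normal((B, H, W, C)),
    image_pe=mx.random.normal((B, H, W, C)),
    sparse_prompt_embeddings=mx.random.normal((B, 3, C)),  # arbitrary sparse tokens
    dense_prompt_embeddings=mx.zeros((B, H, W, C)),
    multimask_output=True,
)
print(masks.shape, iou_pred.shape)  # (1, 3, 256, 256) (1, 3)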

# models/prompt_encoder.py
"""
SAM3 Prompt Encoder - Complete MLX Implementation

Encodes different types of user prompts:
- Points (clicks): Positive/negative points with coordinates
- Boxes: Bounding box coordinates (top-left, bottom-right)
- Masks: Dense mask inputs

Outputs:
- Sparse embeddings: Point and box prompt embeddings
- Dense embeddings: Mask prompt embeddings
"""

import mlx.core as mx
import mlx.nn as nn
from mlx.nn import Module
from typing import Optional, Tuple
import math


class PositionEmbeddingRandom(Module):
    """
    Positional encoding using random spatial frequencies.

    Similar to Fourier features - maps 2D coordinates to a high-dimensional
    space using a random frequency basis.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None):
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.scale = scale

        # Random frequency matrix; each row is a 2D frequency vector
        self.positional_encoding_gaussian_matrix = mx.random.normal(
            shape=(2, num_pos_feats)
        ) * scale

    def _pe_encoding(self, coords: mx.array) -> mx.array:
        """
        Positionally encode points normalized to [0, 1]

        Args:
            coords: (B, N, 2) coordinates in [0, 1] range

        Returns:
            (B, N, num_pos_feats * 2) positional encoding
        """
        coords_scaled = coords * 2 * math.pi

        # Project through the random frequencies:
        # (B, N, 2) @ (2, num_pos_feats) -> (B, N, num_pos_feats)
        projected = coords_scaled @ self.positional_encoding_gaussian_matrix

        # Apply sin and cos and concatenate: (B, N, num_pos_feats * 2)
        return mx.concatenate([mx.sin(projected), mx.cos(projected)], axis=-1)

    def __call__(self, size: Tuple[int, int]) -> mx.array:
        """
        Generate positional encoding for a 2D grid

        Args:
            size: (H, W) grid size

        Returns:
            (H, W, C) positional encoding
        """
        h, w = size

        # Create coordinate grids: y_embed (H, W), x_embed (H, W)
        y_embed = mx.broadcast_to(
            mx.arange(h, dtype=mx.float32).reshape(-1, 1), (h, w)
        )
        x_embed = mx.broadcast_to(
            mx.arange(w, dtype=mx.float32).reshape(1, -1), (h, w)
        )

        # Normalize to [0, 1]
        y_embed = y_embed / h
        x_embed = x_embed / w

        # Stack to (H, W, 2)
        coords = mx.stack([x_embed, y_embed], axis=-1)

        # Encode: add a batch dimension, encode, then restore the grid shape
        coords = coords.reshape(1, h * w, 2)
        pe = self._pe_encoding(coords)
        pe = pe.reshape(h, w, -1)

        return pe

    def forward_with_coords(
        self, coords_input: mx.array, image_size: Tuple[int, int]
    ) -> mx.array:
        """
        Encode arbitrary point coordinates

        Args:
            coords_input: (B, N, 2) in pixel coordinates
            image_size: (H, W) image dimensions for normalization

        Returns:
            (B, N, C) positional encodings
        """
        # Normalize coordinates to [0, 1]: x / W, y / H
        coords = coords_input.astype(mx.float32)
        coords = mx.stack(
            [coords[..., 0] / image_size[1], coords[..., 1] / image_size[0]],
            axis=-1,
        )

        return self._pe_encoding(coords)


class PromptEncoder(Module):
    """
    Complete SAM3 Prompt Encoder

    Encodes prompts into embeddings for the mask decoder:
    - Points: Sparse embeddings with learned type (positive/negative)
    - Boxes: Sparse embeddings for corners (top-left, bottom-right)
    - Masks: Dense embeddings from downsampled mask

    Args:
        embed_dim: Channel dimension for embeddings
        image_embedding_size: Size of image embeddings from encoder
        input_image_size: Original input image size
        mask_in_chans: Hidden channels for the mask encoder (default 16)
    """

    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int = 16,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size

        # Positional encoding for points and boxes
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        # Learnable embeddings for different prompt types
        self.num_point_embeddings = 4  # pos, neg, top-left corner, bottom-right corner
        self.point_embeddings = [
            nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)
        ]

        # Embedding for the "not a point" padding slot
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        # Mask downsampling encoder: two stride-2 convs give 4x downsampling,
        # so mask prompts are expected at 4x the embedding grid resolution
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            nn.LayerNorm(mask_in_chans // 4),
            nn.GELU(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            nn.LayerNorm(mask_in_chans),
            nn.GELU(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )

        # No-mask embedding (used when no mask prompt is provided)
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> mx.array:
        """
        Get positional encoding for the image embedding grid

        Returns:
            (H, W, C) dense positional encoding
        """
        return self.pe_layer(self.image_embedding_size)

    def _embed_points(
        self,
        points: mx.array,
        labels: mx.array,
        pad: bool,
    ) -> mx.array:
        """
        Embed point prompts

        Args:
            points: (B, N, 2) point coordinates
            labels: (B, N) point labels (0=negative, 1=positive)
            pad: Whether to pad with the "not a point" embedding

        Returns:
            (B, N, C) or (B, N+1, C) point embeddings
        """
        # Positional encoding, shifted to pixel centers
        points = points + 0.5
        point_embedding = self.pe_layer.forward_with_coords(
            points, self.input_image_size
        )

        # Add the learned type embedding per point:
        # label 1 -> positive, anything else (including 0) -> negative
        B, N, C = point_embedding.shape
        neg = self.point_embeddings[0].weight.reshape(1, 1, -1)
        pos = self.point_embeddings[1].weight.reshape(1, 1, -1)
        is_positive = (labels == 1).reshape(B, N, 1)
        point_embedding = point_embedding + mx.where(is_positive, pos, neg)

        # Pad with the "not a point" embedding if requested
        if pad:
            padding_point = mx.broadcast_to(
                self.not_a_point_embed.weight.reshape(1, 1, -1), (B, 1, C)
            )
            point_embedding = mx.concatenate([point_embedding, padding_point], axis=1)

        return point_embedding

    def _embed_boxes(self, boxes: mx.array) -> mx.array:
        """
        Embed box prompts

        Args:
            boxes: (B, 4) boxes as [x0, y0, x1, y1]

        Returns:
            (B, 2, C) corner embeddings [top-left, bottom-right]
        """
        boxes = boxes + 0.5  # Shift to pixel centers

        # Split into corners: (B, 2, 2)
        coords = mx.stack(
            [
                boxes[:, :2],  # top-left [x0, y0]
                boxes[:, 2:],  # bottom-right [x1, y1]
            ],
            axis=1,
        )

        # Positional encoding for the corners: (B, 2, C)
        corner_embedding = self.pe_layer.forward_with_coords(
            coords, self.input_image_size
        )

        # Add the learned corner type embeddings
        corner_types = mx.concatenate(
            [self.point_embeddings[2].weight, self.point_embeddings[3].weight],
            axis=0,
        )  # (2, C)
        corner_embedding = corner_embedding + corner_types.reshape(1, 2, -1)

        return corner_embedding

    def _embed_masks(self, masks: mx.array) -> mx.array:
        """
        Embed mask prompts

        Args:
            masks: (B, 1, H, W) dense masks

        Returns:
            (B, H_emb, W_emb, C) downsampled mask embeddings
        """
        # MLX convolutions expect NHWC, so move the channel axis last
        masks = masks.transpose(0, 2, 3, 1)
        return self.mask_downscaling(masks)

    def __call__(
        self,
        points: Optional[Tuple[mx.array, mx.array]] = None,
        boxes: Optional[mx.array] = None,
        masks: Optional[mx.array] = None,
    ) -> Tuple[mx.array, mx.array]:
        """
        Encode prompts into sparse and dense embeddings

        Args:
            points: Optional tuple of (coords, labels)
                - coords: (B, N, 2) point coordinates
                - labels: (B, N) point labels (0=neg, 1=pos)
            boxes: Optional (B, 4) boxes as [x0, y0, x1, y1]
            masks: Optional (B, 1, H, W) mask prompts

        Returns:
            sparse_embeddings: (B, N_sparse, C) point/box embeddings
            dense_embeddings: (B, H_emb, W_emb, C) mask embeddings
        """
        bs = 1  # Default batch size

        # Handle sparse prompts (points and boxes)
        sparse_embeddings_list = []

        if points is not None:
            coords, labels = points
            bs = coords.shape[0]
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings_list.append(point_embeddings)

        if boxes is not None:
            bs = boxes.shape[0]
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings_list.append(box_embeddings)

        # Concatenate all sparse embeddings
        if len(sparse_embeddings_list) > 0:
            sparse_embeddings = mx.concatenate(sparse_embeddings_list, axis=1)
        else:
            # No sparse prompts - fall back to the "not a point" embedding
            sparse_embeddings = mx.broadcast_to(
                self.not_a_point_embed.weight.reshape(1, 1, -1),
                (bs, 1, self.embed_dim),
            )

        # Handle dense prompts (masks)
        if masks is not None:
            bs = masks.shape[0]
            dense_embeddings = self._embed_masks(masks)
        else:
            # No mask prompt - broadcast no_mask_embed across the embedding grid
            H, W = self.image_embedding_size
            dense_embeddings = mx.broadcast_to(
                self.no_mask_embed.weight.reshape(1, 1, 1, -1),
                (bs, H, W, self.embed_dim),
            )

        return sparse_embeddings, dense_embeddings


def create_prompt_encoder(
    embed_dim: int = 256,
    image_embedding_size: Tuple[int, int] = (64, 64),
    input_image_size: Tuple[int, int] = (1024, 1024),
) -> PromptEncoder:
    """
    Factory function to create a SAM3 prompt encoder

    Args:
        embed_dim: Embedding dimension
        image_embedding_size: Size of vision encoder output
        input_image_size: Size of input images

    Returns:
        PromptEncoder instance
    """
    return PromptEncoder(
        embed_dim=embed_dim,
        image_embedding_size=image_embedding_size,
        input_image_size=input_image_size,
    )
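
A minimal sketch of the encoder above with a single positive click. Coordinates are given in input-image pixels per the docstrings, and the shapes in the trailing comments follow directly from the code (the point embedding is padded with the "not a point" slot because no box is given):

import mlx.core as mx
from models.prompt_encoder import create_prompt_encoder

encoder = create_prompt_encoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
)
coords = mx.array([[[512, 384]]]).astype(mx.float32)  # (B=1, N=1, 2)
labels = mx.array([[1]]).astype(mx.float32)           # 1 = positive click
sparse_emb, dense_emb = encoder(points=(coords, labels), boxes=None, masks=None)
print(sparse_emb.shape, dense_emb.shape)  # (1, 2, 256) (1, 64, 64, 256)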

# pyproject.toml
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "sam3-mlx"
version = "0.1.0"
description = "Segment Anything Model 3 (SAM3) implemented in Apple MLX for native Metal acceleration"
readme = "README.md"
requires-python = ">=3.9"
license = {text = "MIT"}
authors = [
    {name = "SAM3 MLX Contributors"},
]
keywords = [
    "segment-anything",
    "sam3",
    "mlx",
    "apple-silicon",
    "computer-vision",
    "segmentation",
    "metal",
    "machine-learning",
    "deep-learning",
]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "Intended Audience :: Science/Research",
    "License :: OSI Approved :: MIT License",
    "Operating System :: MacOS",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Scientific/Engineering :: Image Recognition",
]

dependencies = [
    "mlx>=0.20.0",
    "numpy>=1.23.0",
    "pillow>=9.0.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0",
    "pytest-cov>=4.0",
    "black>=23.0",
    "ruff>=0.1.0",
    "mypy>=1.0",
]
examples = [
    "matplotlib>=3.5.0",
    "tqdm>=4.65.0",
]
all = [
    "sam3-mlx[dev,examples]",
]

[project.urls]
Homepage = "https://github.com/yourusername/sam3-mlx"
Repository = "https://github.com/yourusername/sam3-mlx"
Documentation = "https://github.com/yourusername/sam3-mlx#readme"
"Bug Tracker" = "https://github.com/yourusername/sam3-mlx/issues"

[project.scripts]
sam3-segment = "sam3_mlx.cli:main"

[tool.setuptools]
packages = ["sam3_mlx", "sam3_mlx.models", "sam3_mlx.utils"]

[tool.setuptools.package-data]
sam3_mlx = ["py.typed"]

[tool.black]
line-length = 100
target-version = ['py39', 'py310', 'py311']
include = '\.pyi?$'

[tool.ruff]
line-length = 100
target-version = "py39"
select = ["E", "F", "I", "N", "W"]
ignore = ["E501"]

[tool.mypy]
python_version = "3.9"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
disallow_incomplete_defs = true

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"
python_classes = "Test*"
python_functions = "test_*"
addopts = "-v --cov=sam3_mlx --cov-report=term-missing"
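
With this configuration, a development install from a checkout of the repository root can pull in the optional dependency groups together, e.g. `pip install -e ".[dev,examples]"` (assuming a Python >= 3.9 environment on macOS, per `requires-python` above).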

# requirements.txt
# Core dependencies
mlx>=0.20.0
numpy>=1.23.0
pillow>=9.0.0

# Optional: for examples
matplotlib>=3.5.0
tqdm>=4.65.0

# models/sam3.py
"""
SAM3 MLX - Main Model Class

Complete Segment Anything Model 3 implementation in MLX.
Ties together: Vision Encoder, Prompt Encoder, Mask Decoder.
"""

import mlx.core as mx
import mlx.nn as nn
from mlx.nn import Module
from pathlib import Path
import json
import numpy as np
from typing import Dict, Optional, Tuple, Any

from .hiera import create_hiera_base, create_hiera_large
from .prompt_encoder import create_prompt_encoder, PromptEncoder
from .mask_decoder import create_mask_decoder, MaskDecoder


class SAM3MLX(Module):
    """
    Complete SAM3 Model in MLX

    Architecture:
    1. Vision Encoder (Hiera) - Encodes image to features
    2. Prompt Encoder - Encodes user prompts (points/boxes/masks)
    3. Mask Decoder - Predicts segmentation masks

    Full production-ready implementation with all components integrated.
    """

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        image_encoder_variant: str = "base",
    ):
        super().__init__()

        if config is None:
            config = self.default_config()

        self.config = config

        # Extract configuration
        self.image_size = config.get("image_size", 1024)
        self.embed_dim = config.get("prompt_embed_dim", 256)

        # Vision encoder (Hiera)
        print("Initializing Hiera vision encoder...")
        if image_encoder_variant == "large":
            self.vision_encoder = create_hiera_large()
            vision_embed_dim = 1536
        else:
            self.vision_encoder = create_hiera_base()
            vision_embed_dim = 1024

        # Calculate the image embedding grid after patch embedding and
        # downsampling. With patch_size=14 and three 2x downsample stages:
        # 1024 // 14 = 73 patches per side, then 73 -> 36 -> 18 -> 9.
        patch_grid_size = self.image_size // config.get("patch_size", 14)
        num_downsample = len(config.get("embed_dims", [256, 512, 1024, 1024])) - 1
        image_embedding_size = patch_grid_size // (2 ** num_downsample)
        self.image_embedding_size = (image_embedding_size, image_embedding_size)

        print(f"  Image embedding grid: {self.image_embedding_size}")

        # Prompt encoder
        print("Initializing prompt encoder...")
        self.prompt_encoder = create_prompt_encoder(
            embed_dim=self.embed_dim,
            image_embedding_size=self.image_embedding_size,
            input_image_size=(self.image_size, self.image_size),
        )

        # Mask decoder
        print("Initializing mask decoder...")
        self.mask_decoder = create_mask_decoder(
            transformer_dim=self.embed_dim,
            num_multimask_outputs=3,
        )

        # Projection from vision encoder to decoder dimension
        if vision_embed_dim != self.embed_dim:
            self.neck = nn.Sequential(
                nn.Conv2d(vision_embed_dim, self.embed_dim, kernel_size=1, bias=False),
                nn.LayerNorm(self.embed_dim),
                nn.Conv2d(self.embed_dim, self.embed_dim, kernel_size=3, padding=1, bias=False),
                nn.LayerNorm(self.embed_dim),
            )
        else:
            self.neck = nn.Identity()

        print("SAM3 MLX initialized")
        print(f"  Vision backbone: Hiera-{image_encoder_variant.capitalize()}")
        print(f"  Embed dims: {config.get('embed_dims', 'default')}")
        print(f"  Prompt embed dim: {self.embed_dim}")
        print(f"  Image size: {self.image_size}x{self.image_size}")

    @staticmethod
    def default_config() -> Dict[str, Any]:
        """Default SAM3 configuration"""
        return {
            "image_size": 1024,
            "patch_size": 14,
            "embed_dims": [256, 512, 1024, 1024],
            "depths": [2, 8, 16, 6],
            "num_heads": [4, 8, 16, 16],
            "mlp_ratio": 4.0,
            "prompt_embed_dim": 256,
        }

    def encode_image(self, image: mx.array) -> mx.array:
        """
        Encode image to feature embeddings

        Args:
            image: (B, H, W, C) in NHWC format

        Returns:
            (B, H_emb, W_emb, C) image features
        """
        # Get vision encoder features: (B, num_patches, embed_dim)
        features = self.vision_encoder(image)

        # Reshape to spatial format
        B, N, C = features.shape
        H, W = self.image_embedding_size
        features = features.reshape(B, H, W, C)

        # Project to decoder dimension
        features = self.neck(features)

        return features

    def __call__(
        self,
        image: mx.array,
        points: Optional[Tuple[mx.array, mx.array]] = None,
        boxes: Optional[mx.array] = None,
        masks: Optional[mx.array] = None,
        multimask_output: bool = True,
    ) -> Dict[str, mx.array]:
        """
        Full forward pass with prompts

        Args:
            image: (B, H, W, C) input image in NHWC format
            points: Optional tuple of (coords, labels)
                - coords: (B, N, 2) point coordinates
                - labels: (B, N) point labels (0=neg, 1=pos)
            boxes: Optional (B, 4) boxes as [x0, y0, x1, y1]
            masks: Optional (B, 1, H, W) mask prompts
            multimask_output: Return 3 masks (True) or 1 mask (False)

        Returns:
            Dictionary containing:
            - masks: (B, num_masks, H, W) predicted masks
            - iou_predictions: (B, num_masks) quality scores
            - low_res_masks: (B, num_masks, H_low, W_low) low-res masks
        """
        # Encode image
        image_embeddings = self.encode_image(image)  # (B, H_emb, W_emb, C)

        # Encode prompts
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=masks,
        )

        # Get the dense positional encoding for the image: (H_emb, W_emb, C),
        # then broadcast it to the batch size
        image_pe = self.prompt_encoder.get_dense_pe()
        B = image_embeddings.shape[0]
        image_pe = mx.broadcast_to(
            image_pe.reshape(1, *image_pe.shape), (B, *image_pe.shape)
        )

        # Predict masks at the decoder's low resolution
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        # Upsample the low-res masks toward the input resolution
        masks = self._upsample_masks(low_res_masks, self.image_size)

        return {
            "masks": masks,
            "iou_predictions": iou_predictions,
            "low_res_masks": low_res_masks,
        }

    def _upsample_masks(self, masks: mx.array, target_size: int) -> mx.array:
        """
        Upsample masks toward the target size

        Args:
            masks: (B, num_masks, H, W)
            target_size: Target spatial size

        Returns:
            (B, num_masks, H*scale, W*scale) where scale = target_size // H
            (exactly target_size only when target_size is a multiple of H)
        """
        B, num_masks, H, W = masks.shape

        # For now, use simple nearest-neighbor upsampling
        # TODO: Implement proper bilinear interpolation in MLX
        scale = target_size // H

        # Repeat each pixel scale x scale times
        masks_up = mx.repeat(masks, scale, axis=2)  # Upsample height
        masks_up = mx.repeat(masks_up, scale, axis=3)  # Upsample width

        return masks_up

    def predict(
        self,
        image: mx.array,
        point_coords: Optional[mx.array] = None,
        point_labels: Optional[mx.array] = None,
        box: Optional[mx.array] = None,
        mask_input: Optional[mx.array] = None,
        multimask_output: bool = True,
    ) -> Dict[str, mx.array]:
        """
        Convenience method for prediction

        Args:
            image: (H, W, C) or (B, H, W, C) input image
            point_coords: Optional (N, 2) or (B, N, 2) point coordinates
            point_labels: Optional (N,) or (B, N) point labels
            box: Optional (4,) or (B, 4) bounding box
            mask_input: Optional (1, H, W) or (B, 1, H, W) mask
            multimask_output: Return multiple masks

        Returns:
            Prediction dictionary
        """
        # Add a batch dimension if needed
        if len(image.shape) == 3:
            image = image.reshape(1, *image.shape)

        # Prepare points
        points = None
        if point_coords is not None and point_labels is not None:
            if len(point_coords.shape) == 2:
                point_coords = point_coords.reshape(1, *point_coords.shape)
            if len(point_labels.shape) == 1:
                point_labels = point_labels.reshape(1, *point_labels.shape)
            points = (point_coords, point_labels)

        # Prepare box
        boxes = None
        if box is not None:
            if len(box.shape) == 1:
                box = box.reshape(1, -1)
            boxes = box

        # Prepare mask
        masks = None
        if mask_input is not None:
            if len(mask_input.shape) == 3:
                mask_input = mask_input.reshape(1, *mask_input.shape)
            masks = mask_input

        return self(
            image=image,
            points=points,
            boxes=boxes,
            masks=masks,
            multimask_output=multimask_output,
        )

    @classmethod
    def from_checkpoint(cls, checkpoint_dir: str):
        """
        Load SAM3 from an MLX checkpoint directory

        Args:
            checkpoint_dir: Path to directory containing:
                - sam3_mlx_config.json
                - sam3_mlx_weights.npz

        Returns:
            Loaded SAM3MLX model
        """
        checkpoint_dir = Path(checkpoint_dir)

        # Load config
        config_path = checkpoint_dir / "sam3_mlx_config.json"
        if not config_path.exists():
            raise FileNotFoundError(f"Config not found: {config_path}")

        with open(config_path) as f:
            config = json.load(f)

        print(f"Loading SAM3 from {checkpoint_dir}")
        print(f"  Config: {config.get('vision_backbone', 'unknown')} backbone")

        # Create model
        model = cls(config)

        # Load weights
        weights_path = checkpoint_dir / "sam3_mlx_weights.npz"
        if weights_path.exists():
            print(f"Loading weights from {weights_path.name}...")
            model.load_weights(str(weights_path))
        else:
            print(f"Weights not found at {weights_path}, using random initialization")

        return model

    def load_weights(self, weights_path: str):
        """
        Load converted MLX weights.

        Note: this intentionally shadows nn.Module.load_weights. It is a
        simplified version - a full implementation would properly map all
        weights to their corresponding layers.
        """
        print(f"Loading weights from {weights_path}")

        weights_np = np.load(weights_path)

        # Filter vision encoder weights
        vision_weights = {}
        for name in weights_np.files:
            if name.startswith('vision_encoder.'):
                # Remove the prefix
                key = name.replace('vision_encoder.', '')
                vision_weights[key] = mx.array(weights_np[name])

        print(f"Loaded {len(vision_weights)} vision encoder parameters")

        # TODO: Implement proper weight loading to all components
        # For now, we've demonstrated the structure

        return self


def create_sam3_mlx(config: Optional[Dict] = None) -> SAM3MLX:
    """
    Factory function to create a SAM3 MLX model

    Args:
        config: Optional configuration dict

    Returns:
        SAM3MLX model instance
    """
    return SAM3MLX(config=config)
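
A minimal end-to-end sketch of the `predict` convenience method above, again with random weights, so it validates shapes and wiring rather than mask quality. Instantiating the full Hiera backbone makes this the slowest of these examples:

import mlx.core as mx
from models.sam3 import create_sam3_mlx

model = create_sam3_mlx()
result = model.predict(
    image=mx.random.normal((1024, 1024, 3)),  # HWC; predict() adds the batch dim
    point_coords=mx.array([[512, 384]]).astype(mx.float32),
    point_labels=mx.array([1]).astype(mx.float32),
    multimask_output=True,
)
print(result["masks"].shape, result["iou_predictions"].shape)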
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for SAM3 MLX models
|
| 3 |
+
|
| 4 |
+
Validates that all model components work correctly
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
try:
|
| 8 |
+
import pytest
|
| 9 |
+
except ImportError:
|
| 10 |
+
pytest = None
|
| 11 |
+
|
| 12 |
+
import mlx.core as mx
|
| 13 |
+
import sys
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
# Add parent directory to path
|
| 17 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 18 |
+
|
| 19 |
+
from models.attention import MultiHeadAttentionRoPE, WindowedAttention, RoPEEmbedding
|
| 20 |
+
from models.hiera import HieraVisionEncoder, create_hiera_base
|
| 21 |
+
from models.prompt_encoder import PromptEncoder, create_prompt_encoder
|
| 22 |
+
from models.mask_decoder import MaskDecoder, create_mask_decoder
|
| 23 |
+
from models.sam3 import SAM3MLX
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TestAttention:
|
| 27 |
+
"""Test attention modules"""
|
| 28 |
+
|
| 29 |
+
def test_rope_embedding(self):
|
| 30 |
+
"""Test RoPE embedding generation"""
|
| 31 |
+
rope = RoPEEmbedding(dim=64, max_seq_len=1024)
|
| 32 |
+
emb = rope.forward(seq_len=256)
|
| 33 |
+
|
| 34 |
+
assert emb.shape == (2, 256, 64), f"Wrong shape: {emb.shape}"
|
| 35 |
+
print("β
RoPE embedding test passed")
|
| 36 |
+
|
| 37 |
+
def test_multihead_attention_rope(self):
|
| 38 |
+
"""Test multi-head attention with RoPE"""
|
| 39 |
+
attn = MultiHeadAttentionRoPE(dim=256, num_heads=8, use_rope=True)
|
| 40 |
+
|
| 41 |
+
# Create dummy input
|
| 42 |
+
x = mx.random.normal((2, 64, 256)) # (batch, seq_len, dim)
|
| 43 |
+
|
| 44 |
+
# Forward pass
|
| 45 |
+
out = attn(x)
|
| 46 |
+
|
| 47 |
+
assert out.shape == x.shape, f"Wrong output shape: {out.shape}"
|
| 48 |
+
print("β
Multi-head attention RoPE test passed")
|
| 49 |
+
|
| 50 |
+
def test_windowed_attention(self):
|
| 51 |
+
"""Test windowed attention"""
|
| 52 |
+
attn = WindowedAttention(dim=256, num_heads=8, window_size=14)
|
| 53 |
+
|
| 54 |
+
x = mx.random.normal((2, 64, 256))
|
| 55 |
+
out = attn(x)
|
| 56 |
+
|
| 57 |
+
assert out.shape == x.shape
|
| 58 |
+
print("β
Windowed attention test passed")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class TestHiera:
|
| 62 |
+
"""Test Hiera vision encoder"""
|
| 63 |
+
|
| 64 |
+
def test_hiera_base(self):
|
| 65 |
+
"""Test Hiera-Base encoder"""
|
| 66 |
+
encoder = create_hiera_base()
|
| 67 |
+
|
| 68 |
+
# Create dummy image (1024x1024 RGB in NHWC format)
|
| 69 |
+
image = mx.random.normal((1, 1024, 1024, 3))
|
| 70 |
+
|
| 71 |
+
# Forward pass
|
| 72 |
+
features = encoder(image)
|
| 73 |
+
|
| 74 |
+
# Check output shape
|
| 75 |
+
# After patch embedding (1024/14 = 73) and 3 downsample layers (73/8 = 9)
|
| 76 |
+
# Should be (1, 81, 1024) - approximately 9x9 grid
|
| 77 |
+
batch, num_patches, embed_dim = features.shape
|
| 78 |
+
|
| 79 |
+
assert batch == 1, f"Wrong batch size: {batch}"
|
| 80 |
+
assert embed_dim == 1024, f"Wrong embed dim: {embed_dim}"
|
| 81 |
+
# Approximately 9x9 = 81 patches
|
| 82 |
+
assert 70 < num_patches < 90, f"Wrong number of patches: {num_patches}"
|
| 83 |
+
|
| 84 |
+
print(f"β
Hiera-Base test passed - output shape: {features.shape}")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class TestPromptEncoder:
|
| 88 |
+
"""Test prompt encoder"""
|
| 89 |
+
|
| 90 |
+
def test_point_encoding(self):
|
| 91 |
+
"""Test point prompt encoding"""
|
| 92 |
+
encoder = create_prompt_encoder(
|
| 93 |
+
embed_dim=256,
|
| 94 |
+
image_embedding_size=(64, 64),
|
| 95 |
+
input_image_size=(1024, 1024),
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# Create point prompts
|
| 99 |
+
point_coords = mx.array([[[512, 384]]]).astype(mx.float32) # (1, 1, 2)
|
| 100 |
+
point_labels = mx.array([[1]]).astype(mx.float32) # (1, 1)
|
| 101 |
+
|
| 102 |
+
sparse_emb, dense_emb = encoder(
|
| 103 |
+
points=(point_coords, point_labels),
|
| 104 |
+
boxes=None,
|
| 105 |
+
masks=None,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Check sparse embeddings (should include padding)
|
| 109 |
+
assert sparse_emb.shape[0] == 1 # batch
|
| 110 |
+
assert sparse_emb.shape[2] == 256 # embed_dim
|
| 111 |
+
|
| 112 |
+
# Check dense embeddings
|
| 113 |
+
assert dense_emb.shape == (1, 64, 64, 256)
|
| 114 |
+
|
| 115 |
+
print("β
Prompt encoder point test passed")
|
| 116 |
+
|
| 117 |
+
def test_box_encoding(self):
|
| 118 |
+
"""Test box prompt encoding"""
|
| 119 |
+
encoder = create_prompt_encoder(embed_dim=256)
|
| 120 |
+
|
| 121 |
+
# Create box prompt [x0, y0, x1, y1]
|
| 122 |
+
box = mx.array([[100, 100, 500, 500]]).astype(mx.float32)
|
| 123 |
+
|
| 124 |
+
sparse_emb, dense_emb = encoder(
|
| 125 |
+
points=None,
|
| 126 |
+
boxes=box,
|
| 127 |
+
masks=None,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# Should have 2 corner embeddings
|
| 131 |
+
assert sparse_emb.shape[1] == 2
|
| 132 |
+
assert sparse_emb.shape[2] == 256
|
| 133 |
+
|
| 134 |
+
print("β
Prompt encoder box test passed")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class TestMaskDecoder:
|
| 138 |
+
"""Test mask decoder"""
|
| 139 |
+
|
| 140 |
+
def test_mask_decoder(self):
|
| 141 |
+
"""Test mask decoder forward pass"""
|
| 142 |
+
decoder = create_mask_decoder(transformer_dim=256)
|
| 143 |
+
|
| 144 |
+
# Create dummy inputs
|
| 145 |
+
+        B, H, W, C = 1, 64, 64, 256
+        image_embeddings = mx.random.normal((B, H, W, C))
+        image_pe = mx.random.normal((B, H, W, C))
+        sparse_prompt_embeddings = mx.random.normal((B, 3, C))
+        dense_prompt_embeddings = mx.zeros((B, H, W, C))
+
+        # Forward pass
+        masks, iou_pred = decoder(
+            image_embeddings=image_embeddings,
+            image_pe=image_pe,
+            sparse_prompt_embeddings=sparse_prompt_embeddings,
+            dense_prompt_embeddings=dense_prompt_embeddings,
+            multimask_output=True,
+        )
+
+        # Check outputs
+        assert masks.shape[0] == B
+        assert masks.shape[1] == 3  # 3 masks in multimask mode
+        assert iou_pred.shape == (B, 3)
+
+        print(f"✅ Mask decoder test passed - masks shape: {masks.shape}")
+
+
+class TestSAM3:
+    """Test complete SAM3 model"""
+
+    def test_sam3_initialization(self):
+        """Test SAM3 model initialization"""
+        model = SAM3MLX()
+
+        assert model is not None
+        assert hasattr(model, 'vision_encoder')
+        assert hasattr(model, 'prompt_encoder')
+        assert hasattr(model, 'mask_decoder')
+
+        print("✅ SAM3 initialization test passed")
+
+    def test_sam3_forward(self):
+        """Test SAM3 forward pass"""
+        model = SAM3MLX()
+
+        # Create dummy inputs
+        image = mx.random.normal((1, 1024, 1024, 3))
+        point_coords = mx.array([[[512, 384]]]).astype(mx.float32)
+        point_labels = mx.array([[1]]).astype(mx.float32)
+
+        # Forward pass
+        result = model.predict(
+            image=image,
+            point_coords=point_coords,
+            point_labels=point_labels,
+            multimask_output=True,
+        )
+
+        # Check outputs
+        assert "masks" in result
+        assert "iou_predictions" in result
+
+        masks = result["masks"]
+        iou_pred = result["iou_predictions"]
+
+        assert masks.shape[0] == 1  # batch
+        assert masks.shape[1] == 3  # 3 masks
+        assert iou_pred.shape == (1, 3)
+
+        print("✅ SAM3 forward test passed")
+        print(f"   Masks shape: {masks.shape}")
+        print(f"   IoU predictions shape: {iou_pred.shape}")
+
+
+if __name__ == "__main__":
+    print("🧪 Running SAM3 MLX Tests\n")
+    print("=" * 60)
+
+    # Run tests
+    test_suite = [
+        ("Attention Tests", TestAttention),
+        ("Hiera Tests", TestHiera),
+        ("Prompt Encoder Tests", TestPromptEncoder),
+        ("Mask Decoder Tests", TestMaskDecoder),
+        ("SAM3 Tests", TestSAM3),
+    ]
+
+    passed = 0
+    failed = 0
+
+    for suite_name, test_class in test_suite:
+        print(f"\n{suite_name}")
+        print("-" * 60)
+
+        test_instance = test_class()
+        methods = [m for m in dir(test_instance) if m.startswith('test_')]
+
+        for method_name in methods:
+            try:
+                method = getattr(test_instance, method_name)
+                method()
+                passed += 1
+            except Exception as e:
+                print(f"❌ {method_name} failed: {e}")
+                failed += 1
+
+    print("\n" + "=" * 60)
+    print(f"Test Results: {passed} passed, {failed} failed")
+
+    if failed == 0:
+        print("✅ All tests passed!")
+        exit(0)
+    else:
+        print(f"❌ {failed} tests failed")
+        exit(1)
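tests/test_models.py closes with a small self-contained runner, so the suite works without any test framework. Because the classes and methods follow pytest's `Test*`/`test_*` naming, they can also be collected by pytest or run by hand; the sketch below mirrors the `__main__` runner for a single suite (the importable module name `test_models` is an assumption about how you invoke it, e.g. from inside `tests/`).

```python
# Minimal sketch: run one suite in isolation, mirroring the __main__ runner above.
# Assumes tests/test_models.py is importable as `test_models`.
from test_models import TestMaskDecoder

suite = TestMaskDecoder()
for name in (m for m in dir(suite) if m.startswith("test_")):
    getattr(suite, name)()  # any failed assert raises and stops the loop
```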
utils/weights.py
@@ -0,0 +1,263 @@
+"""
+Weight Loading and Saving Utilities for SAM3 MLX
+
+Handles:
+- Loading converted MLX weights from .npz files
+- Saving model weights
+- Weight name mapping between PyTorch and MLX
+"""
+
+import mlx.core as mx
+import numpy as np
+from pathlib import Path
+from typing import Dict, Any
+import json
+
+
+def map_pytorch_to_mlx_name(pytorch_name: str) -> str:
+    """
+    Map PyTorch parameter names to MLX parameter names
+
+    The PyTorch checkpoint uses different naming conventions:
+    - image_encoder/trunk module prefixes instead of vision_encoder
+    - Different module paths
+
+    Args:
+        pytorch_name: PyTorch parameter name
+
+    Returns:
+        MLX parameter name
+    """
+    name = pytorch_name
+
+    # Vision encoder mappings
+    name = name.replace("image_encoder.", "vision_encoder.")
+    name = name.replace("trunk.", "")
+
+    # Attention, prompt encoder, and mask decoder paths already match,
+    # so no remapping is needed for them.
+
+    # Layer norms: both PyTorch and MLX LayerNorm expose weight/bias,
+    # so those names also pass through unchanged.
+
+    return name
+
+
+def load_weights(
+    model: Any,
+    weights_path: str,
+    strict: bool = False,
+    verbose: bool = True,
+) -> Any:
+    """
+    Load MLX weights from .npz file into model
+
+    Args:
+        model: SAM3MLX model instance
+        weights_path: Path to .npz weights file
+        strict: If True, every model parameter must be present in the checkpoint
+        verbose: Print loading statistics
+
+    Returns:
+        Model with loaded weights
+    """
+    weights_path = Path(weights_path)
+
+    if not weights_path.exists():
+        raise FileNotFoundError(f"Weights file not found: {weights_path}")
+
+    if verbose:
+        print(f"📥 Loading weights from {weights_path.name}")
+
+    # Load numpy arrays
+    weights_np = np.load(weights_path)
+
+    # Get model parameter tree
+    model_params = model.parameters()
+    model_param_names = set(_flatten_params(model_params).keys())
+
+    # Convert and load weights
+    loaded_count = 0
+    skipped_count = 0
+    missing_params = set(model_param_names)
+
+    for param_name in weights_np.files:
+        # Map PyTorch name to MLX name
+        mlx_name = map_pytorch_to_mlx_name(param_name)
+
+        # Check if parameter exists in model
+        if mlx_name in model_param_names:
+            # Convert to MLX array
+            param_data = mx.array(weights_np[param_name])
+
+            # Set parameter in model
+            _set_param(model, mlx_name, param_data)
+
+            loaded_count += 1
+            missing_params.discard(mlx_name)
+        else:
+            skipped_count += 1
+            if verbose and strict:
+                print(f"  ⚠️ Skipped: {param_name} (not found in model)")
+
+    if verbose:
+        print(f"✅ Loaded {loaded_count} parameters")
+        if skipped_count > 0:
+            print(f"  ⏭️ Skipped {skipped_count} parameters")
+        if len(missing_params) > 0:
+            print(f"  ❌ Missing {len(missing_params)} parameters in checkpoint")
+            if strict:
+                for param in list(missing_params)[:10]:  # Show first 10
+                    print(f"    - {param}")
+
+    if strict and len(missing_params) > 0:
+        raise ValueError(
+            f"Missing {len(missing_params)} parameters in checkpoint. "
+            "Use strict=False to load partial weights."
+        )
+
+    return model
+
+
+def save_weights(
+    model: Any,
+    weights_path: str,
+    verbose: bool = True,
+) -> None:
+    """
+    Save model weights to .npz file
+
+    Args:
+        model: SAM3MLX model instance
+        weights_path: Path to save .npz weights file
+        verbose: Print saving statistics
+    """
+    weights_path = Path(weights_path)
+    weights_path.parent.mkdir(parents=True, exist_ok=True)
+
+    if verbose:
+        print(f"💾 Saving weights to {weights_path.name}")
+
+    # Get model parameters
+    model_params = _flatten_params(model.parameters())
+
+    # Convert to numpy
+    weights_np = {}
+    for name, param in model_params.items():
+        weights_np[name] = np.array(param)
+
+    # Save
+    np.savez(weights_path, **weights_np)
+
+    if verbose:
+        file_size_mb = weights_path.stat().st_size / (1024 * 1024)
+        print(f"✅ Saved {len(weights_np)} parameters ({file_size_mb:.2f} MB)")
+
+
+def _flatten_params(params: Dict, prefix: str = "", sep: str = ".") -> Dict[str, mx.array]:
+    """
+    Flatten nested parameter dictionary
+
+    Args:
+        params: Nested parameter dict
+        prefix: Current prefix for parameter names
+        sep: Separator for parameter names
+
+    Returns:
+        Flattened dict of {name: array}
+    """
+    flat = {}
+
+    for key, value in params.items():
+        full_key = f"{prefix}{sep}{key}" if prefix else key
+
+        if isinstance(value, dict):
+            # Recurse into nested dict
+            flat.update(_flatten_params(value, full_key, sep))
+        elif isinstance(value, mx.array):
+            # Leaf parameter
+            flat[full_key] = value
+        elif isinstance(value, list):
+            # List of parameters (e.g., from nn.Sequential)
+            for i, item in enumerate(value):
+                if isinstance(item, dict):
+                    flat.update(_flatten_params(item, f"{full_key}.{i}", sep))
+                elif isinstance(item, mx.array):
+                    flat[f"{full_key}.{i}"] = item
+
+    return flat
+
+
+def _set_param(model: Any, param_name: str, value: mx.array) -> None:
+    """
+    Set a parameter in the model by dotted name
+
+    Args:
+        model: Model instance
+        param_name: Dotted parameter name (e.g., "vision_encoder.patch_embed.proj.weight")
+        value: Parameter value
+    """
+    parts = param_name.split(".")
+    obj = model
+
+    # Navigate to the parent object
+    for part in parts[:-1]:
+        if part.isdigit():
+            # List index (e.g., a block inside a list of layers)
+            obj = obj[int(part)]
+        elif hasattr(obj, part):
+            obj = getattr(obj, part)
+        else:
+            raise AttributeError(f"Cannot find {part} in {type(obj)}")
+
+    # Set the final attribute
+    final_attr = parts[-1]
+    if hasattr(obj, final_attr):
+        setattr(obj, final_attr, value)
+    else:
+        raise AttributeError(f"Cannot set {final_attr} in {type(obj)}")
+
+
+def load_config(config_path: str) -> Dict[str, Any]:
+    """
+    Load model configuration from JSON file
+
+    Args:
+        config_path: Path to config JSON file
+
+    Returns:
+        Configuration dictionary
+    """
+    config_path = Path(config_path)
+
+    if not config_path.exists():
+        raise FileNotFoundError(f"Config file not found: {config_path}")
+
+    with open(config_path) as f:
+        config = json.load(f)
+
+    return config
+
+
+def save_config(config: Dict[str, Any], config_path: str) -> None:
+    """
+    Save model configuration to JSON file
+
+    Args:
+        config: Configuration dictionary
+        config_path: Path to save config JSON file
+    """
+    config_path = Path(config_path)
+    config_path.parent.mkdir(parents=True, exist_ok=True)
+
+    with open(config_path, 'w') as f:
+        json.dump(config, f, indent=2)
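Taken together, these utilities cover the full checkpoint round trip: remap PyTorch names, load into the model, and save back out. A minimal sketch of that flow follows; the file paths are placeholders, the import paths assume you run from the repo root, and `SAM3MLX(**config)` is an assumption about the constructor, so adjust to the actual signature.

```python
# Hypothetical round trip with utils/weights.py; the paths and the
# SAM3MLX(**config) call are assumptions, not part of the source.
from models.sam3 import SAM3MLX
from utils.weights import (
    load_config, load_weights, map_pytorch_to_mlx_name, save_weights,
)

# PyTorch-style names are remapped on the fly while loading:
assert map_pytorch_to_mlx_name(
    "image_encoder.trunk.blocks.0.norm1.weight"
) == "vision_encoder.blocks.0.norm1.weight"

config = load_config("weights/config.json")           # placeholder path
model = SAM3MLX(**config)                             # assumed constructor
model = load_weights(model, "weights/sam3_mlx.npz")   # strict=False by default
save_weights(model, "weights/sam3_mlx_roundtrip.npz")
```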