"""
Weight Loading and Saving Utilities for SAM3 MLX
Handles:
- Loading converted MLX weights from .npz files
- Saving model weights
- Weight name mapping between PyTorch and MLX
"""
import mlx.core as mx
import numpy as np
from pathlib import Path
from typing import Dict, Any
import json
def map_pytorch_to_mlx_name(pytorch_name: str) -> str:
"""
Map PyTorch parameter names to MLX parameter names
    PyTorch checkpoints use different naming conventions:
    - "image_encoder." / "trunk." prefixes instead of MLX's "vision_encoder." path
    - Different module paths for some submodules
Args:
pytorch_name: PyTorch parameter name
Returns:
MLX parameter name
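    Example (illustrative checkpoint key, not taken from a real file):
        >>> map_pytorch_to_mlx_name("image_encoder.trunk.blocks.0.attn.qkv.weight")
        'vision_encoder.blocks.0.attn.qkv.weight'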
"""
# Direct mappings
name = pytorch_name
# Vision encoder mappings
name = name.replace("image_encoder.", "vision_encoder.")
name = name.replace("trunk.", "")
    # Attention mappings: "attn.qkv." parameter names already match, no change needed
    # Layer norm mappings: MLX LayerNorm also uses weight/bias, so no change needed
    # Prompt encoder mappings: "prompt_encoder.point_embeddings" already matches
    # Mask decoder mappings: "mask_decoder.transformer." and
    # "mask_decoder.output_upscaling." already match the MLX module layout
return name
def load_weights(
model: Any,
weights_path: str,
strict: bool = False,
verbose: bool = True,
) -> Any:
"""
Load MLX weights from .npz file into model
Args:
model: SAM3MLX model instance
weights_path: Path to .npz weights file
strict: If True, all parameters must match exactly
verbose: Print loading statistics
Returns:
Model with loaded weights
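    Example (sketch; `SAM3MLX` and the file name are placeholders for the actual
    model class and converted checkpoint):
        model = SAM3MLX(config)
        model = load_weights(model, "sam3_mlx_weights.npz", strict=False)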
"""
weights_path = Path(weights_path)
if not weights_path.exists():
raise FileNotFoundError(f"Weights file not found: {weights_path}")
if verbose:
print(f"📥 Loading weights from {weights_path.name}")
# Load numpy arrays
weights_np = np.load(weights_path)
# Get model parameter tree
model_params = model.parameters()
model_param_names = set(_flatten_params(model_params).keys())
# Convert and load weights
loaded_count = 0
skipped_count = 0
missing_params = set(model_param_names)
for param_name in weights_np.files:
# Map PyTorch name to MLX name
mlx_name = map_pytorch_to_mlx_name(param_name)
# Check if parameter exists in model
if mlx_name in model_param_names:
# Convert to MLX array
param_data = mx.array(weights_np[param_name])
# Set parameter in model
_set_param(model, mlx_name, param_data)
loaded_count += 1
missing_params.discard(mlx_name)
else:
skipped_count += 1
if verbose and strict:
print(f" ⚠️ Skipped: {param_name} (not found in model)")
if verbose:
print(f"✅ Loaded {loaded_count} parameters")
if skipped_count > 0:
print(f" ⏭️ Skipped {skipped_count} parameters")
if len(missing_params) > 0:
print(f" ❌ Missing {len(missing_params)} parameters in checkpoint")
if strict:
for param in list(missing_params)[:10]: # Show first 10
print(f" - {param}")
if strict and len(missing_params) > 0:
raise ValueError(
f"Missing {len(missing_params)} parameters in checkpoint. "
"Use strict=False to load partial weights."
)
return model
def save_weights(
model: Any,
weights_path: str,
verbose: bool = True,
) -> None:
"""
Save model weights to .npz file
Args:
model: SAM3MLX model instance
weights_path: Path to save .npz weights file
verbose: Print saving statistics
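    Example (sketch; the output path is a placeholder):
        save_weights(model, "checkpoints/sam3_mlx.npz")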
"""
weights_path = Path(weights_path)
weights_path.parent.mkdir(parents=True, exist_ok=True)
if verbose:
print(f"💾 Saving weights to {weights_path.name}")
# Get model parameters
model_params = _flatten_params(model.parameters())
# Convert to numpy
weights_np = {}
for name, param in model_params.items():
weights_np[name] = np.array(param)
# Save
np.savez(weights_path, **weights_np)
if verbose:
file_size_mb = weights_path.stat().st_size / (1024 * 1024)
print(f"✅ Saved {len(weights_np)} parameters ({file_size_mb:.2f} MB)")
def _flatten_params(params: Dict, prefix: str = "", sep: str = ".") -> Dict[str, mx.array]:
"""
Flatten nested parameter dictionary
Args:
params: Nested parameter dict
prefix: Current prefix for parameter names
sep: Separator for parameter names
Returns:
Flattened dict of {name: array}
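    Example (toy nested dict mirroring the structure model.parameters() returns):
        >>> nested = {"enc": {"w": mx.array([1.0])}, "blocks": [{"w": mx.array([2.0])}]}
        >>> sorted(_flatten_params(nested).keys())
        ['blocks.0.w', 'enc.w']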
"""
flat = {}
for key, value in params.items():
full_key = f"{prefix}{sep}{key}" if prefix else key
if isinstance(value, dict):
# Recurse into nested dict
flat.update(_flatten_params(value, full_key, sep))
elif isinstance(value, mx.array):
# Leaf parameter
flat[full_key] = value
elif isinstance(value, list):
# List of parameters (e.g., from nn.Sequential)
for i, item in enumerate(value):
if isinstance(item, dict):
flat.update(_flatten_params(item, f"{full_key}.{i}", sep))
elif isinstance(item, mx.array):
flat[f"{full_key}.{i}"] = item
return flat
def _set_param(model: Any, param_name: str, value: mx.array) -> None:
"""
Set a parameter in the model by dotted name
Args:
model: Model instance
param_name: Dotted parameter name (e.g., "vision_encoder.patch_embed.proj.weight")
value: Parameter value
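    Example (sketch; assumes the model actually exposes this parameter path):
        _set_param(model, "vision_encoder.patch_embed.proj.weight", new_weight)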
"""
parts = param_name.split(".")
obj = model
# Navigate to the parent object
for part in parts[:-1]:
if part.isdigit():
# List index
obj = obj[int(part)]
elif hasattr(obj, part):
obj = getattr(obj, part)
        else:
            # Neither a list index nor an attribute of the current module
            raise AttributeError(f"Cannot find {part} in {type(obj)}")
# Set the final attribute
final_attr = parts[-1]
if hasattr(obj, final_attr):
setattr(obj, final_attr, value)
else:
raise AttributeError(f"Cannot set {final_attr} in {type(obj)}")
def load_config(config_path: str) -> Dict[str, Any]:
"""
Load model configuration from JSON file
Args:
config_path: Path to config JSON file
Returns:
Configuration dictionary
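    Example (sketch; the config file name is a placeholder):
        config = load_config("sam3_mlx_config.json")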
"""
config_path = Path(config_path)
if not config_path.exists():
raise FileNotFoundError(f"Config file not found: {config_path}")
with open(config_path) as f:
config = json.load(f)
return config
def save_config(config: Dict[str, Any], config_path: str) -> None:
"""
Save model configuration to JSON file
Args:
config: Configuration dictionary
config_path: Path to save config JSON file
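    Example (sketch; the keys and path are placeholders):
        save_config({"image_size": 1024}, "sam3_mlx_config.json")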
"""
config_path = Path(config_path)
config_path.parent.mkdir(parents=True, exist_ok=True)
with open(config_path, 'w') as f:
json.dump(config, f, indent=2)
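# ---------------------------------------------------------------------------
# Usage sketch tying the helpers together. Assumptions: a `SAM3MLX` model
# class provided elsewhere in this package, and converted config/weight files
# with these placeholder names.
#
#     config = load_config("sam3_mlx_config.json")
#     model = SAM3MLX(**config)                     # hypothetical constructor
#     model = load_weights(model, "sam3_mlx_weights.npz", strict=False)
#     save_weights(model, "checkpoints/sam3_mlx_copy.npz")
# ---------------------------------------------------------------------------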