| """ |
| Weight Loading and Saving Utilities for SAM3 MLX |
| |
| Handles: |
| - Loading converted MLX weights from .npz files |
| - Saving model weights |
| - Weight name mapping between PyTorch and MLX |
| """ |
|
|
| import mlx.core as mx |
| import numpy as np |
| from pathlib import Path |
| from typing import Dict, Any, Optional |
| import json |
|
|
|
|
def map_pytorch_to_mlx_name(pytorch_name: str) -> str:
    """
    Map a PyTorch parameter name to its MLX counterpart.

    Only two renames are actually required by the converted checkpoints:
    - the PyTorch ``image_encoder.`` prefix becomes ``vision_encoder.``
    - the intermediate ``trunk.`` path component is flattened away

    All other parameter paths (attention qkv projections, prompt encoder
    point embeddings, mask decoder transformer/upscaling, ...) already use
    identical names on both sides, so they pass through unchanged.

    Args:
        pytorch_name: Parameter name as found in the PyTorch checkpoint.

    Returns:
        The corresponding MLX parameter name.
    """
    # Vision encoder: rename the prefix and drop the "trunk." nesting level.
    name = pytorch_name.replace("image_encoder.", "vision_encoder.")
    name = name.replace("trunk.", "")

    # NOTE(review): the original implementation also performed several no-op
    # replacements (e.g. "attn.qkv." -> "attn.qkv.") for components whose
    # names already match between frameworks; they were removed as dead code.
    return name
|
|
|
|
def load_weights(
    model: Any,
    weights_path: str,
    strict: bool = False,
    verbose: bool = True,
) -> Any:
    """
    Load MLX weights from a .npz file into a model (in place).

    Checkpoint entries are renamed via ``map_pytorch_to_mlx_name`` and
    assigned into the model only when a parameter with that name exists;
    unmatched checkpoint entries are skipped.

    Args:
        model: SAM3MLX model instance (anything exposing ``parameters()``)
        weights_path: Path to the .npz weights file
        strict: If True, raise ValueError when the checkpoint does not cover
            every model parameter; also prints unmatched checkpoint entries
        verbose: Print loading statistics

    Returns:
        The same model instance, with matching parameters replaced.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
        ValueError: If ``strict`` and parameters are missing from the checkpoint.
    """
    weights_path = Path(weights_path)

    if not weights_path.exists():
        raise FileNotFoundError(f"Weights file not found: {weights_path}")

    if verbose:
        print(f"📥 Loading weights from {weights_path.name}")

    # Names of every parameter the model expects, in flattened dotted form.
    model_param_names = set(_flatten_params(model.parameters()).keys())

    loaded_count = 0
    skipped_count = 0
    missing_params = set(model_param_names)

    # BUG FIX: np.load on a .npz returns a lazily-read NpzFile that keeps the
    # underlying file handle open; the original never closed it. Use it as a
    # context manager so the handle is always released.
    with np.load(weights_path) as weights_np:
        for param_name in weights_np.files:
            mlx_name = map_pytorch_to_mlx_name(param_name)

            if mlx_name in model_param_names:
                # Materialize the numpy array as an MLX array and assign it.
                _set_param(model, mlx_name, mx.array(weights_np[param_name]))
                loaded_count += 1
                missing_params.discard(mlx_name)
            else:
                skipped_count += 1
                if verbose and strict:
                    print(f" ⚠️ Skipped: {param_name} (not found in model)")

    if verbose:
        print(f"✅ Loaded {loaded_count} parameters")
        if skipped_count > 0:
            print(f" ⏭️ Skipped {skipped_count} parameters")
        if len(missing_params) > 0:
            print(f" ❌ Missing {len(missing_params)} parameters in checkpoint")
            if strict:
                # Show at most 10 names to keep the output readable.
                for param in list(missing_params)[:10]:
                    print(f" - {param}")

    if strict and len(missing_params) > 0:
        raise ValueError(
            f"Missing {len(missing_params)} parameters in checkpoint. "
            "Use strict=False to load partial weights."
        )

    return model
|
|
|
|
def save_weights(
    model: Any,
    weights_path: str,
    verbose: bool = True,
) -> None:
    """
    Save model weights to a .npz file.

    Args:
        model: SAM3MLX model instance (anything exposing ``parameters()``)
        weights_path: Path to save the .npz weights file
        verbose: Print saving statistics
    """
    weights_path = Path(weights_path)
    # Ensure the destination directory exists before writing.
    weights_path.parent.mkdir(parents=True, exist_ok=True)

    if verbose:
        print(f"💾 Saving weights to {weights_path.name}")

    # Flatten the (possibly nested) parameter tree to dotted names and
    # convert every MLX array to numpy so np.savez can serialize it.
    flat_params = _flatten_params(model.parameters())
    weights_np = {name: np.array(value) for name, value in flat_params.items()}

    np.savez(weights_path, **weights_np)

    if verbose:
        file_size_mb = weights_path.stat().st_size / (1024 * 1024)
        print(f"✅ Saved {len(weights_np)} parameters ({file_size_mb:.2f} MB)")
|
|
|
|
def _flatten_params(params: Dict, prefix: str = "", sep: str = ".") -> Dict[str, mx.array]:
    """
    Flatten a nested parameter dictionary into {dotted_name: array}.

    Dict values recurse with the key appended to the prefix; list values
    recurse per element with the element index as a path component; mx.array
    leaves are recorded directly. Anything else is silently ignored.

    Args:
        params: Nested parameter dict
        prefix: Current prefix for parameter names
        sep: Separator for parameter names

    Returns:
        Flattened dict of {name: array}
    """
    flattened = {}

    for key, child in params.items():
        name = f"{prefix}{sep}{key}" if prefix else key

        if isinstance(child, dict):
            # Nested submodule: recurse with the extended prefix.
            flattened.update(_flatten_params(child, name, sep))
        elif isinstance(child, mx.array):
            # Leaf tensor: record it under the current dotted name.
            flattened[name] = child
        elif isinstance(child, list):
            # Lists hold either submodule dicts or bare arrays; the element
            # index becomes a path component.
            for idx, entry in enumerate(child):
                if isinstance(entry, dict):
                    flattened.update(_flatten_params(entry, f"{name}.{idx}", sep))
                elif isinstance(entry, mx.array):
                    flattened[f"{name}.{idx}"] = entry

    return flattened
|
|
|
|
| def _set_param(model: Any, param_name: str, value: mx.array) -> None: |
| """ |
| Set a parameter in the model by dotted name |
| |
| Args: |
| model: Model instance |
| param_name: Dotted parameter name (e.g., "vision_encoder.patch_embed.proj.weight") |
| value: Parameter value |
| """ |
| parts = param_name.split(".") |
| obj = model |
|
|
| |
| for part in parts[:-1]: |
| if part.isdigit(): |
| |
| obj = obj[int(part)] |
| elif hasattr(obj, part): |
| obj = getattr(obj, part) |
| else: |
| |
| raise AttributeError(f"Cannot find {part} in {type(obj)}") |
|
|
| |
| final_attr = parts[-1] |
| if hasattr(obj, final_attr): |
| setattr(obj, final_attr, value) |
| else: |
| raise AttributeError(f"Cannot set {final_attr} in {type(obj)}") |
|
|
|
|
def load_config(config_path: str) -> Dict[str, Any]:
    """
    Load a model configuration from a JSON file.

    Args:
        config_path: Path to the config JSON file

    Returns:
        Configuration dictionary

    Raises:
        FileNotFoundError: If the config file does not exist.
    """
    config_path = Path(config_path)

    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")

    # Read the whole document and parse it as JSON.
    return json.loads(config_path.read_text())
|
|
|
|
def save_config(config: Dict[str, Any], config_path: str) -> None:
    """
    Save a model configuration to a JSON file.

    Args:
        config: Configuration dictionary
        config_path: Path to save the config JSON file
    """
    target = Path(config_path)
    # Create the destination directory if it is not already there.
    target.parent.mkdir(parents=True, exist_ok=True)

    # Pretty-print with 2-space indentation for human readability.
    with target.open('w') as fh:
        json.dump(config, fh, indent=2)
|