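"""
SAM3 MLX Click Segmentation Example

Demonstrates how to:
1. Load SAM3 MLX model
2. Process an image
3. Segment objects with point clicks
4. Visualize results

Usage:
    python click_segment.py --image path/to/image.jpg --point 100,200
    python click_segment.py --image path/to/image.jpg --point 100,200 --point=-150,220

Points are (x, y) pixel coordinates in the original image. Prefix a point
with "-" for a negative click and pass it as --point=-x,y, because argparse
would otherwise read "-x,y" as an option name.
"""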
|
|
import argparse
import time
from pathlib import Path
from typing import List, Optional, Tuple

import mlx.core as mx
import numpy as np
|
|
import sys

# Image/plotting dependencies: fail with an install hint instead of a bare
# ImportError traceback. sys.exit is used rather than the site-provided
# exit() builtin, which is not guaranteed to exist (e.g. `python -S`).
try:
    from PIL import Image
    import matplotlib.pyplot as plt
except ImportError:
    print("❌ Please install PIL and matplotlib:", file=sys.stderr)
    print("   pip install pillow matplotlib", file=sys.stderr)
    sys.exit(1)

# Put the parent of this script's directory on sys.path so the local
# `models` and `utils` packages import regardless of the working directory.
# resolve() matters when __file__ is relative (e.g. Python < 3.9 running
# `python click_segment.py` from this directory): without it
# Path(__file__).parent.parent collapses to "." instead of "..".
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from models.sam3 import SAM3MLX
from utils.weights import load_weights
|
|
|
|
def load_image(image_path: str, target_size: int = 1024) -> Tuple[mx.array, np.ndarray]:
    """
    Load and preprocess an image for SAM3.

    The image is converted to RGB and resized (bilinear) to a
    ``target_size`` x ``target_size`` square. Aspect ratio is NOT preserved.
    Pixel values are then scaled to float32 in [0, 1]. No mean/std
    normalization is applied here; presumably the model handles that
    (TODO confirm against SAM3MLX.predict).

    Args:
        image_path: Path to image file
        target_size: Side length of the square model input (default 1024)

    Returns:
        Tuple of (preprocessed MLX array of shape
        (1, target_size, target_size, 3), original RGB numpy array of
        shape (H, W, 3))
    """
    # The context manager closes the file deterministically. Pillow may keep
    # the handle open after loading, e.g. for multi-frame formats.
    # convert() loads pixel data, so `img` stays usable after the with-block.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    original = np.array(img)

    img_resized = img.resize((target_size, target_size), Image.BILINEAR)
    img_np = np.asarray(img_resized, dtype=np.float32) / 255.0

    # (target, target, 3) -> add a leading batch dimension (NHWC).
    img_mlx = mx.array(img_np).reshape(1, target_size, target_size, 3)

    return img_mlx, original
|
|
|
|
def visualize_prediction(
    image: np.ndarray,
    masks: mx.array,
    point_coords: mx.array,
    point_labels: mx.array,
    iou_scores: mx.array,
    save_path: Optional[str] = None,
):
    """
    Visualize segmentation results.

    The first panel shows the input image with the clicks:
    - points with label 1 are drawn as green circles;
    - all other points are drawn as red crosses.

    Each remaining panel shows one predicted mask as a 50%-alpha green
    overlay, titled with its IoU score.

    Args:
        image: Original image (H, W, 3), presumably uint8 RGB
        masks: Predicted masks (1, num_masks, h, w); each is resized to (H, W)
        point_coords: Input point coordinates (1, N, 2) as (x, y) in ``image`` pixels
        point_labels: Input point labels (1, N)
        iou_scores: IoU quality scores (1, num_masks)
        save_path: Optional path to save visualization
    """
    masks_np = np.array(masks[0])
    point_coords_np = np.array(point_coords[0])
    point_labels_np = np.array(point_labels[0])
    iou_scores_np = np.array(iou_scores[0])

    num_masks = masks_np.shape[0]
    H, W = image.shape[:2]

    # squeeze=False always yields a 2-D grid of axes, so indexing works for
    # any num_masks (including 0, where subplots would return a bare Axes).
    fig, axes = plt.subplots(
        1, num_masks + 1, figsize=(5 * (num_masks + 1), 5), squeeze=False
    )
    axes = axes[0]

    # Panel 0: input image with the clicked points.
    axes[0].imshow(image)
    axes[0].set_title("Input Image with Points")
    for (x, y), label in zip(point_coords_np, point_labels_np):
        positive = label == 1
        axes[0].scatter(
            x,
            y,
            c='g' if positive else 'r',
            marker='o' if positive else 'x',
            s=200,
            linewidths=3,
        )
    axes[0].axis('off')

    green = np.array([0, 255, 0])
    for i, mask in enumerate(masks_np):
        # Clip before the uint8 cast: values outside [0, 1] (e.g. if the model
        # returns logits; not confirmed here) would otherwise wrap around.
        mask_u8 = (np.clip(mask.astype(np.float32), 0.0, 1.0) * 255).astype(np.uint8)
        mask_img = Image.fromarray(mask_u8).resize((W, H), Image.BILINEAR)
        # Per-pixel blend weight in [0, 0.5], broadcast over the RGB channels.
        alpha = (np.asarray(mask_img) / 255.0)[..., None] * 0.5
        overlay = (image * (1 - alpha) + green * alpha).astype(np.uint8)

        axes[i + 1].imshow(overlay)
        axes[i + 1].set_title(f"Mask {i+1} (IoU: {iou_scores_np[i]:.3f})")
        axes[i + 1].axis('off')

    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, bbox_inches='tight', dpi=150)
        print(f"💾 Saved visualization to {save_path}")

    plt.show()
    # Release the figure; with a non-interactive backend show() is a no-op
    # and the figure would otherwise stay alive in pyplot's registry.
    plt.close(fig)
|
|
|
|
def _parse_points(point_strs: List[str]) -> Tuple[List[List[float]], List[int]]:
    """Parse raw ``--point`` values into coordinates and labels.

    Accepted forms are ``x,y`` and ``+x,y`` (positive click, label 1) and
    ``-x,y`` (negative click, label 0).

    Args:
        point_strs: Raw ``--point`` strings from the command line.

    Returns:
        ``(coords, labels)``: ``coords`` is a list of ``[x, y]`` floats and
        ``labels`` holds 1 (positive) or 0 (negative) per point.

    Raises:
        ValueError: If a value is not two comma-separated numbers.
    """
    coords: List[List[float]] = []
    labels: List[int] = []
    for raw in point_strs:
        text = raw.strip()
        label = 0 if text.startswith("-") else 1
        if text[:1] in ("+", "-"):
            text = text[1:]
        parts = text.split(",")
        if len(parts) != 2:
            raise ValueError(
                f"invalid point {raw!r}: expected 'x,y', '+x,y' or '-x,y'"
            )
        try:
            x, y = float(parts[0]), float(parts[1])
        except ValueError:
            raise ValueError(
                f"invalid point {raw!r}: coordinates must be numbers"
            ) from None
        coords.append([x, y])
        labels.append(label)
    return coords, labels


def _scale_points(
    coords: List[List[float]],
    src_size: Tuple[int, int],
    dst_size: Tuple[int, int],
) -> List[List[float]]:
    """Map ``[x, y]`` pixel coordinates from one image size to another.

    Args:
        coords: Points as ``[x, y]`` in the source image.
        src_size: Source image size as ``(width, height)``.
        dst_size: Destination image size as ``(width, height)``.

    Returns:
        A new list of ``[x, y]`` points in the destination frame.
    """
    sx = dst_size[0] / src_size[0]
    sy = dst_size[1] / src_size[1]
    return [[x * sx, y * sy] for x, y in coords]


def main():
    """Run click-based segmentation from the command line.

    Steps:
    1. Parse arguments and validate the clicks.
    2. Load the image, and the weights if present.
    3. Run ``SAM3MLX.predict``, then print a summary and show/save a
       visualization.

    Invalid arguments terminate via ``parser.error`` (exit status 2).
    """
    parser = argparse.ArgumentParser(description="SAM3 MLX Click Segmentation Example")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument(
        "--point",
        type=str,
        action="append",
        help=(
            "Click point as 'x,y' in original-image pixels (can specify multiple). "
            "Use +x,y for positive, -x,y for negative; pass negative points as "
            "--point=-x,y so they are not mistaken for an option"
        ),
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="./checkpoints/sam3_mlx",
        help="Path to SAM3 MLX checkpoint directory",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Path to save output visualization",
    )
    parser.add_argument(
        "--single-mask",
        action="store_true",
        help="Output single mask instead of 3 masks",
    )
    args = parser.parse_args()

    # Validate the clicks before any expensive work (image/model loading).
    if not args.point:
        parser.error("please specify at least one point with --point x,y")
    try:
        point_coords_list, point_labels_list = _parse_points(args.point)
    except ValueError as err:
        parser.error(str(err))

    print("🚀 SAM3 MLX Click Segmentation Example")
    print("=" * 60)

    print(f"📍 Input points: {len(point_coords_list)}")
    for i, (coord, label) in enumerate(zip(point_coords_list, point_labels_list), start=1):
        label_str = "positive" if label == 1 else "negative"
        print(f"   Point {i}: ({coord[0]:.0f}, {coord[1]:.0f}) [{label_str}]")

    print(f"\n📸 Loading image: {args.image}")
    image_mlx, image_original = load_image(args.image)
    orig_h, orig_w = image_original.shape[:2]
    print(f"   Image size: {orig_w}x{orig_h}")

    # load_image returns a resized (1, H, W, 3) tensor. The model is never
    # told the original size, so the clicks, which are given in
    # original-image pixels, must be mapped into the resized frame.
    model_h, model_w = image_mlx.shape[1], image_mlx.shape[2]
    model_coords_list = _scale_points(
        point_coords_list, (orig_w, orig_h), (model_w, model_h)
    )

    # Original-frame coordinates are kept for plotting on the original image.
    point_coords = mx.array(point_coords_list).reshape(1, -1, 2)
    model_point_coords = mx.array(model_coords_list).reshape(1, -1, 2)
    point_labels = mx.array(point_labels_list).reshape(1, -1)

    print("\n🏗️ Initializing SAM3 MLX model...")
    model = SAM3MLX()

    checkpoint_dir = Path(args.checkpoint)
    weights_path = checkpoint_dir / "sam3_mlx_weights.npz"

    if weights_path.exists():
        print(f"\n📥 Loading weights from {checkpoint_dir}")
        model = load_weights(model, str(weights_path), strict=False, verbose=True)
    else:
        print(f"\n⚠️ Weights not found at {weights_path}")
        print("   Using randomly initialized model (for testing architecture only)")

    print("\n🎯 Running segmentation...")
    start_time = time.perf_counter()

    result = model.predict(
        image=image_mlx,
        point_coords=model_point_coords,
        point_labels=point_labels,
        multimask_output=not args.single_mask,
    )

    # MLX is lazy: evaluate every output used below so the timing covers the
    # whole forward pass rather than only the mask branch.
    mx.eval(result["masks"], result["iou_predictions"])

    inference_time = (time.perf_counter() - start_time) * 1000
    print(f"✅ Inference completed in {inference_time:.1f}ms")

    masks = result["masks"]
    iou_predictions = result["iou_predictions"]

    print("\n📊 Results:")
    print(f"   Number of masks: {masks.shape[1]}")
    print(f"   Mask resolution: {masks.shape[2]}x{masks.shape[3]}")
    print(f"   IoU scores: {np.array(iou_predictions[0])}")

    print("\n🎨 Visualizing results...")
    visualize_prediction(
        image_original,
        masks,
        point_coords,
        point_labels,
        iou_predictions,
        save_path=args.output,
    )

    print("\n✅ Done!")
|
|
|
|
if __name__ == "__main__":
    # Run the example only when executed as a script, not when imported.
    main()
|
|