| """ |
| SAM (Segment Anything Model) Integration for Pneumonia Consolidation |
| This script uses Meta's Segment Anything Model to generate initial segmentation masks |
| that can be refined manually. |
| """ |
|
|
| import numpy as np |
| import cv2 |
| from pathlib import Path |
| import matplotlib.pyplot as plt |
| import argparse |
|
|
|
|
def setup_sam():
    """
    Import the Segment Anything package and return its entry points.

    Install with:
        pip install segment-anything

    Download a model checkpoint from:
    https://github.com/facebookresearch/segment-anything#model-checkpoints

    Returns:
        (sam_model_registry, SamPredictor) when the package is importable,
        otherwise (None, None) after printing install instructions.
    """
    try:
        from segment_anything import sam_model_registry, SamPredictor
    except ImportError:
        # Package missing: tell the user how to get it, signal failure.
        print("Error: segment-anything not installed.")
        print("Install with: pip install segment-anything")
        print("Then download a model checkpoint from:")
        print("https://github.com/facebookresearch/segment-anything#model-checkpoints")
        return None, None
    return sam_model_registry, SamPredictor
|
|
|
|
def initialize_sam_predictor(checkpoint_path, model_type="vit_h"):
    """
    Build a SAM predictor from a checkpoint file.

    Args:
        checkpoint_path: Path to SAM checkpoint (.pth file)
        model_type: Model type ('vit_h', 'vit_l', or 'vit_b')

    Returns:
        A SamPredictor instance, or None when segment-anything is not
        installed (setup_sam already printed instructions in that case).
    """
    registry, predictor_cls = setup_sam()
    if registry is None:
        return None

    # Instantiate the requested backbone and wrap it in a predictor.
    backbone = registry[model_type](checkpoint=checkpoint_path)
    return predictor_cls(backbone)
|
|
|
|
def predict_consolidation_with_points(image_path, predictor, point_coords, point_labels):
    """
    Generate candidate segmentation masks using point prompts.

    Args:
        image_path: Path to chest X-ray image
        predictor: SAM predictor object
        point_coords: Array of [x, y] coordinates for prompts
        point_labels: Array of labels (1 for positive/include, 0 for negative/exclude)

    Returns:
        masks: Candidate binary segmentation masks (multimask output)
        scores: Confidence score for each candidate mask
        image: The loaded RGB image

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    image = cv2.imread(str(image_path))
    if image is None:
        # cv2.imread returns None instead of raising on a bad/unreadable path;
        # fail loudly here rather than with a cryptic cvtColor error.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # OpenCV loads BGR; SAM expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    predictor.set_image(image)

    point_coords = np.array(point_coords)
    point_labels = np.array(point_labels)

    # Request multiple candidate masks so the caller can pick the best one.
    masks, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True
    )

    return masks, scores, image
|
|
|
|
def predict_consolidation_with_box(image_path, predictor, box_coords):
    """
    Generate candidate segmentation masks using a bounding box prompt.

    Args:
        image_path: Path to chest X-ray image
        predictor: SAM predictor object
        box_coords: [x1, y1, x2, y2] bounding box coordinates

    Returns:
        masks: Candidate binary segmentation masks (multimask output)
        scores: Confidence score for each candidate mask
        image: The loaded RGB image

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    image = cv2.imread(str(image_path))
    if image is None:
        # cv2.imread returns None instead of raising on a bad/unreadable path.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # OpenCV loads BGR; SAM expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    predictor.set_image(image)

    box = np.array(box_coords)

    # Request multiple candidate masks so the caller can pick the best one.
    masks, scores, logits = predictor.predict(
        box=box,
        multimask_output=True
    )

    return masks, scores, image
|
|
|
|
def automatic_consolidation_detection(image_path, predictor, grid_size=5):
    """
    Automatically detect potential consolidation regions using grid-based sampling.

    Prompts SAM with a grid_size x grid_size grid of positive points and
    unions every mask whose confidence exceeds 0.8.

    Args:
        image_path: Path to chest X-ray image
        predictor: SAM predictor object
        grid_size: Number of points in grid (grid_size x grid_size)

    Returns:
        (combined_mask, image): uint8 union of all confident masks (or None
        when no prompt scored above threshold), plus the loaded RGB image.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    image = cv2.imread(str(image_path))
    if image is None:
        # cv2.imread returns None instead of raising on a bad/unreadable path.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    h, w = image.shape[:2]

    predictor.set_image(image)

    # Keep prompts away from the image borders; the wider horizontal margin
    # presumably targets the lung fields of a frontal chest X-ray — TODO confirm.
    margin_h = int(h * 0.1)
    margin_w = int(w * 0.2)

    x_coords = np.linspace(margin_w, w - margin_w, grid_size)
    y_coords = np.linspace(margin_h, h - margin_h, grid_size)

    all_masks = []

    for x in x_coords:
        for y in y_coords:
            point = np.array([[x, y]])
            label = np.array([1])

            try:
                masks, scores, _ = predictor.predict(
                    point_coords=point,
                    point_labels=label,
                    multimask_output=False
                )

                # Keep only confident detections.
                if scores[0] > 0.8:
                    all_masks.append(masks[0])
            except Exception:
                # Best-effort scan: one failing prompt must not abort the grid.
                continue

    if not all_masks:
        return None, image

    # Union of all confident masks.
    combined_mask = np.any(all_masks, axis=0).astype(np.uint8)

    return combined_mask, image
|
|
|
|
def visualize_sam_results(image, masks, scores, point_coords=None, save_path=None):
    """
    Visualize SAM segmentation results.

    Shows the original image (optionally with point prompts drawn on it)
    followed by one panel per candidate mask overlaid on the image.

    Args:
        image: Original RGB image
        masks: Sequence of binary masks
        scores: Confidence score per mask
        point_coords: Optional (N, 2) array of point prompts to display
        save_path: Optional path to save visualization
    """
    fig, axes = plt.subplots(1, len(masks) + 1, figsize=(15, 5))
    # plt.subplots returns a bare Axes (not an array) when there is only one
    # panel (masks empty); normalize so the indexing below always works.
    axes = np.atleast_1d(axes)

    axes[0].imshow(image)
    axes[0].set_title('Original')
    axes[0].axis('off')

    if point_coords is not None:
        axes[0].scatter(point_coords[:, 0], point_coords[:, 1],
                        c='red', s=100, marker='*')

    # One overlay panel per candidate mask, titled with its confidence.
    for idx, (mask, score) in enumerate(zip(masks, scores)):
        axes[idx + 1].imshow(image)
        axes[idx + 1].imshow(mask, alpha=0.5, cmap='jet')
        axes[idx + 1].set_title(f'Mask {idx + 1}\nScore: {score:.3f}')
        axes[idx + 1].axis('off')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Visualization saved to: {save_path}")

    plt.show()
|
|
|
|
def save_mask(mask, output_path):
    """Write a binary (0/1) mask to disk as an 8-bit image (0/255)."""
    scaled = (mask * 255).astype(np.uint8)
    cv2.imwrite(str(output_path), scaled)
    print(f"Mask saved to: {output_path}")
|
|
|
|
def interactive_sam_segmentation(image_path, checkpoint_path):
    """
    Interactive segmentation where user clicks points to guide SAM.
    This is a simple CLI version - for GUI, integrate with Streamlit.

    Args:
        image_path: Path to the chest X-ray image to segment.
        checkpoint_path: Path to the SAM checkpoint (.pth file).

    Returns:
        The highest-scoring mask, or None when SAM could not be initialized
        or no prompt points were collected.

    Side effects:
        Opens an OpenCV window to collect clicks, shows a matplotlib
        visualization, and saves the best mask next to the input image
        as '<stem>_sam_mask.png'.
    """
    print("Initializing SAM...")
    predictor = initialize_sam_predictor(checkpoint_path)

    if predictor is None:
        return

    # Load as RGB for SAM/matplotlib (OpenCV reads BGR).
    image = cv2.imread(str(image_path))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    print("\nInstructions:")
    print("1. The image will be displayed")
    print("2. Click on consolidation areas (left click)")
    print("3. Click on background areas to exclude (right click)")
    print("4. Press 'q' when done")
    print("5. Choose best mask from results")

    # Prompt lists mutated in place by the mouse callback closure below.
    point_coords = []
    point_labels = []

    def mouse_callback(event, x, y, flags, param):
        # Left click = positive (include) prompt, right click = negative (exclude).
        if event == cv2.EVENT_LBUTTONDOWN:
            point_coords.append([x, y])
            point_labels.append(1)
            print(f"Added positive point at ({x}, {y})")
        elif event == cv2.EVENT_RBUTTONDOWN:
            point_coords.append([x, y])
            point_labels.append(0)
            print(f"Added negative point at ({x}, {y})")

    # Convert back to BGR because cv2.imshow expects BGR ordering.
    display_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    cv2.namedWindow('Image')
    cv2.setMouseCallback('Image', mouse_callback)

    # Event loop: redraw collected prompts each frame until 'q' is pressed.
    while True:
        display = display_img.copy()

        # Green dot = positive prompt, red dot = negative (colors are BGR).
        for coord, label in zip(point_coords, point_labels):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            cv2.circle(display, tuple(coord), 5, color, -1)

        cv2.imshow('Image', display)

        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break

    cv2.destroyAllWindows()

    if point_coords:
        print("\nGenerating masks...")
        masks, scores, _ = predict_consolidation_with_points(
            image_path, predictor, point_coords, point_labels
        )

        visualize_sam_results(image, masks, scores, np.array(point_coords))

        # Keep the highest-confidence mask and save it beside the input image.
        best_idx = np.argmax(scores)
        output_path = Path(image_path).parent / f"{Path(image_path).stem}_sam_mask.png"
        save_mask(masks[best_idx], output_path)

        return masks[best_idx]

    return None
|
|
|
|
def batch_process_with_sam(input_dir, output_dir, checkpoint_path, mode='auto',
                           model_type="vit_h"):
    """
    Batch process images with SAM.

    Args:
        input_dir: Directory with chest X-ray images
        output_dir: Directory to save masks
        checkpoint_path: Path to SAM checkpoint
        mode: 'auto' for automatic or 'center' for single center point
        model_type: SAM backbone type ('vit_h', 'vit_l', or 'vit_b')
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print("Initializing SAM...")
    predictor = initialize_sam_predictor(checkpoint_path, model_type)

    if predictor is None:
        return

    images = list(input_path.glob("*.jpg")) + list(input_path.glob("*.png"))
    print(f"Found {len(images)} images to process")

    for img_path in images:
        print(f"\nProcessing: {img_path.name}")

        try:
            if mode == 'auto':
                mask, image = automatic_consolidation_detection(img_path, predictor)
            else:
                # 'center' mode: prompt SAM with a single point at image center.
                image = cv2.imread(str(img_path))
                if image is None:
                    # cv2.imread returns None on unreadable files; skip instead
                    # of crashing on image.shape below.
                    print(f"No mask generated for {img_path.name}")
                    continue
                h, w = image.shape[:2]
                center_point = [[w // 2, h // 2]]
                masks, scores, image = predict_consolidation_with_points(
                    img_path, predictor, center_point, [1]
                )
                mask = masks[np.argmax(scores)]

            if mask is not None:
                output_file = output_path / f"{img_path.stem}_mask.png"
                save_mask(mask, output_file)
            else:
                print(f"No mask generated for {img_path.name}")

        except Exception as e:
            # Per-image failures must not stop the whole batch; report and continue.
            print(f"Error processing {img_path.name}: {e}")

    print(f"\nBatch processing complete! Masks saved to: {output_path}")
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: interactive single-image mode or batch directory mode.
    parser = argparse.ArgumentParser(
        description="Generate pneumonia consolidation masks using SAM"
    )
    parser.add_argument(
        '--checkpoint',
        type=str,
        required=True,
        help='Path to SAM checkpoint file (.pth)'
    )
    parser.add_argument(
        '--image',
        type=str,
        help='Path to single image (for interactive mode)'
    )
    parser.add_argument(
        '--input_dir',
        type=str,
        help='Input directory for batch processing'
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        help='Output directory for batch processing'
    )
    parser.add_argument(
        '--mode',
        type=str,
        default='interactive',
        choices=['interactive', 'auto', 'center'],
        help='Processing mode'
    )
    # NOTE(review): --model_type is parsed but never passed to any function
    # below, so the initializer always uses its default backbone.
    parser.add_argument(
        '--model_type',
        type=str,
        default='vit_h',
        choices=['vit_h', 'vit_l', 'vit_b'],
        help='SAM model type'
    )

    args = parser.parse_args()

    # Dispatch: interactive mode needs --image; auto/center batch modes need
    # both --input_dir and --output_dir; otherwise show usage.
    if args.mode == 'interactive' and args.image:
        interactive_sam_segmentation(args.image, args.checkpoint)
    elif args.input_dir and args.output_dir:
        batch_process_with_sam(args.input_dir, args.output_dir,
                               args.checkpoint, args.mode)
    else:
        parser.print_help()
|
|