|
|
""" |
|
|
Utility functions for NeuroSAM 3 application. |
|
|
Helper functions for image processing, visualization, and common operations. |
|
|
""" |
|
|
|
|
|
from typing import Optional, Tuple, List, Dict, Any |
|
|
import os |
|
|
import re |
|
|
import tempfile |
|
|
import numpy as np |
|
|
import pydicom |
|
|
from PIL import Image |
|
|
import matplotlib.pyplot as plt |
|
|
from logger_config import logger |
|
|
|
|
|
|
|
|
def extract_subject_id(file_path: str) -> Tuple[str, str, str]:
    """
    Derive a subject/patient identifier for a file.

    Strategies are tried in order and the first one that produces an ID wins:

    1. DICOM metadata (only for paths ending in ``.dcm``): ``PatientID``,
       then ``StudyInstanceUID``.
    2. Parent folder name: /subject_001/image.png -> subject_001
    3. Keyword + number in the filename:
       patient_123_slice_5.dcm -> patient_123
    4. Two or more letters followed by digits in the filename
       (matched case-insensitively): AB123_t1.png -> AB123
    5. First run of three or more digits in the filename.
    6. Filename without its extension.

    Args:
        file_path: Path to file (converted with ``str()`` before use)

    Returns:
        Tuple of (subject_id, confidence_level, source)
        confidence_level: 'high' (DICOM metadata), 'medium' (folder/filename
            pattern), 'low' (numeric run, bare filename, or nothing usable)
        source: 'dicom_patientid', 'dicom_study', 'folder', 'filename',
            'filename_numeric', 'fallback' or 'unknown'
    """
    path_text = str(file_path)
    directory, leaf = os.path.split(path_text)

    # 1. DICOM header. Any read/parse problem is logged at debug level and
    #    the path-based heuristics below take over.
    if path_text.lower().endswith('.dcm'):
        try:
            dataset = pydicom.dcmread(path_text, stop_before_pixels=True)

            pid = getattr(dataset, 'PatientID', None)
            if pid and pid.strip():
                return f"patient_{pid}", 'high', 'dicom_patientid'

            uid = getattr(dataset, 'StudyInstanceUID', None)
            if uid:
                return f"study_{uid}", 'high', 'dicom_study'
        except Exception as e:
            logger.debug(f"Could not read DICOM metadata: {e}")

    # 2. Parent folder that starts with a subject-like keyword and a number.
    parent = os.path.basename(directory.rstrip('/'))
    if parent not in ('', '.', '..') and re.match(
        r'(subject|patient|sub|pat|case|id)[_-]?\d+', parent, re.I
    ):
        return parent, 'medium', 'folder'

    # 3. Keyword + number anywhere in the filename; the two captured parts
    #    are re-joined with an underscore whatever the original separator was.
    keyword_hit = re.search(
        r'(subject|patient|sub|pat|case|id)[_-]?(\d+)', leaf, re.I
    )
    if keyword_hit:
        return (
            f"{keyword_hit.group(1)}_{keyword_hit.group(2)}",
            'medium',
            'filename',
        )

    # 4. Letters-then-digits token (re.I means lowercase letters match too).
    token_hit = re.search(r'([A-Z]{2,}\d+)', leaf, re.I)
    if token_hit:
        return token_hit.group(1), 'medium', 'filename'

    # 5. Bare run of at least three digits.
    digit_run = re.search(r'(\d{3,})', leaf)
    if digit_run:
        return digit_run.group(1), 'low', 'filename_numeric'

    # 6. Last resort: the filename stem, or a placeholder when it is empty.
    stem = os.path.splitext(leaf)[0]
    if stem:
        return stem, 'low', 'fallback'
    return "unknown", 'low', 'unknown'
|
|
|
|
|
|
|
|
def group_images_by_subject(image_files: List[str]) -> Dict[str, Dict[str, Any]]:
    """
    Group image files by subject/patient ID.

    Each path is classified with ``extract_subject_id`` and appended to the
    group for the returned subject ID.

    Args:
        image_files: List of file paths. A single path (``str`` or
            ``os.PathLike``) is treated as a one-element list, and ``None``
            entries are ignored.

    Returns:
        Dictionary: {subject_id: {'files': [...], 'confidence': 'high/medium/low',
        'sources': [...]}}
        'files' is sorted by the string form of each path, 'confidence' is
        the best level seen for the subject, and 'sources' is a sorted list
        of the distinct detection sources.
        An empty dictionary is returned for empty/None input.
    """
    if not image_files:
        return {}

    # A lone path must not be iterated character by character.
    if isinstance(image_files, (str, os.PathLike)):
        image_files = [image_files]

    subject_groups: Dict[str, Dict[str, Any]] = {}
    for file_path in image_files:
        if file_path is None:
            continue

        subject_id, confidence, source = extract_subject_id(file_path)

        group = subject_groups.setdefault(
            subject_id,
            {'files': [], 'confidence': confidence, 'sources': set()},
        )
        group['files'].append(file_path)
        group['sources'].add(source)

        # Only ever upgrade the stored confidence: low -> medium -> high.
        if confidence == 'high' or (
            confidence == 'medium' and group['confidence'] == 'low'
        ):
            group['confidence'] = confidence

    for group in subject_groups.values():
        # key=str keeps the ordering of plain strings unchanged while also
        # allowing str and Path entries to be mixed.
        group['files'].sort(key=str)
        # Set iteration order for strings changes between interpreter runs,
        # so sort to give callers a reproducible list.
        group['sources'] = sorted(group['sources'])

    return subject_groups
|
|
|
|
|
|
|
|
def _mask_to_2d_bool(mask: np.ndarray) -> np.ndarray:
    """Reduce one mask array to a boolean array (private helper of combine_masks).

    3D inputs with a length-3 first or last axis are averaged over that axis
    and thresholded at 0.5; other 3D inputs contribute only their first slice
    along the shorter of the first/last axes. Inputs with more than three
    dimensions are squeezed and must end up 2D. Arrays with fewer than three
    dimensions keep their shape.

    Raises:
        ValueError: if an array with more than three dimensions is not 2D
            after ``squeeze()``.
    """
    if mask.ndim == 3:
        if mask.shape[0] == 3:
            mask = np.mean(mask, axis=0) > 0.5
        elif mask.shape[2] == 3:
            mask = np.mean(mask, axis=2) > 0.5
        else:
            mask = mask[0] if mask.shape[0] < mask.shape[2] else mask[:, :, 0]
    elif mask.ndim > 3:
        mask = mask.squeeze()
        if mask.ndim != 2:
            # Every axis that survives squeeze() has length != 1, so no
            # reshape to the last two axes can preserve the element count.
            raise ValueError(
                f"Cannot reduce mask of shape {mask.shape} to 2D"
            )

    if mask.dtype != bool:
        # Values in [.., 1] are treated as already binary; anything larger
        # (e.g. 0-255 images) is thresholded at half of its maximum.
        mask = mask.astype(bool) if mask.max() <= 1 else (mask > mask.max() / 2)

    return mask


def combine_masks(masks) -> Optional[np.ndarray]:
    """
    Combine multiple mask arrays into a single mask.

    Each usable mask is reduced to a boolean array, resized (nearest
    neighbour) to the shape of the first mask when shapes differ, and the
    results are OR-ed together.

    Args:
        masks: List of mask arrays, or numpy array, or None. Other iterables
            are materialised with ``list()``. ``None`` entries and empty
            arrays inside the collection are skipped. A numpy array with two
            or more dimensions is returned unchanged (it is not converted to
            bool); a 1D numpy array is treated as a single mask.

    Returns:
        Combined boolean mask array, the input array itself for numpy input
        with two or more dimensions, or None if there are no valid masks or
        combining fails (the failure is logged).
    """
    if masks is None:
        return None

    if isinstance(masks, np.ndarray):
        if masks.ndim == 0:
            return None
        if masks.ndim > 1:
            return masks
        if masks.size == 0:
            return None
        masks = [masks]
    elif not isinstance(masks, (list, tuple)):
        try:
            masks = list(masks)
        except Exception:
            return None

    mask_arrays = []
    for mask in masks:
        if mask is None:
            # A missing mask must not invalidate the remaining ones.
            continue
        if not isinstance(mask, np.ndarray):
            try:
                mask = np.array(mask)
            except Exception as e:
                logger.debug(f"Could not convert mask to numpy: {e}")
                continue
        # Empty arrays carry no information and would make .max() raise.
        if mask.size > 0:
            mask_arrays.append(mask)

    if not mask_arrays:
        return None

    try:
        masks_2d = [_mask_to_2d_bool(mask) for mask in mask_arrays]

        target_shape = masks_2d[0].shape
        mismatched = [
            i for i, mask in enumerate(masks_2d) if mask.shape != target_shape
        ]
        if mismatched:
            # Imported lazily so scipy is only needed when resizing happens.
            from scipy.ndimage import zoom

            for i in mismatched:
                zoom_factors = (
                    target_shape[0] / masks_2d[i].shape[0],
                    target_shape[1] / masks_2d[i].shape[1],
                )
                masks_2d[i] = zoom(
                    masks_2d[i].astype(float), zoom_factors, order=0
                ) > 0.5

        combined_mask = np.any(masks_2d, axis=0)
        return combined_mask.astype(bool)
    except Exception as e:
        logger.error(f"Error combining masks: {e}", exc_info=True)
        return None
|
|
|
|
|
|
|
|
def create_output_image(
    pil_image: Image.Image,
    mask: Optional[np.ndarray],
    prompt_text: str,
    colormap: str = 'spring',
    transparency: float = 0.5,
    title: Optional[str] = None
) -> str:
    """
    Create output visualization image with mask overlay.

    The image is drawn on a 10x10 inch figure, the mask (if any) is drawn on
    top of it, and the result is saved as a PNG in a new temporary file that
    the caller is responsible for deleting.

    Args:
        pil_image: Base PIL image
        mask: Optional mask array to overlay (2D or 3D). Numpy masks with
            three or more dimensions are reduced to 2D and non-boolean numpy
            masks are binarised before drawing; other objects are passed to
            ``imshow`` unchanged.
        prompt_text: Prompt text for title
        colormap: Matplotlib colormap name
        transparency: Mask transparency (0.0-1.0), used as the overlay alpha
        title: Optional custom title; defaults to
            "Segmentation: <prompt_text>"

    Returns:
        Path to saved output image

    Raises:
        Exception: any error raised while preparing the mask, drawing or
            saving is propagated. The figure is always closed and a
            partially written temporary file is removed first.
    """
    # Resolve configuration before acquiring any resources.
    from config import OUTPUT_DPI

    # Use an explicit figure/axes pair rather than pyplot's implicit
    # "current figure", so drawing and closing always target this figure.
    fig, ax = plt.subplots(figsize=(10, 10))
    output_path = None
    try:
        ax.imshow(pil_image)

        if mask is not None:
            if isinstance(mask, np.ndarray):
                if mask.ndim == 3:
                    if mask.shape[0] == 3:
                        # Channels-first, three channels: average them.
                        mask = np.mean(mask, axis=0)
                    elif mask.shape[2] == 3:
                        # Channels-last, three channels: average them.
                        mask = np.mean(mask, axis=2)
                    else:
                        # Otherwise keep only the first slice along the
                        # shorter of the first/last axes.
                        mask = mask[0] if mask.shape[0] < mask.shape[2] else mask[:, :, 0]
                elif mask.ndim > 3:
                    mask = mask.squeeze()
                    if mask.ndim != 2:
                        logger.warning(f"Mask has {mask.ndim} dimensions, expected 2D. Flattening...")
                        # NOTE(review): after squeeze() no axis has length 1,
                        # so for non-empty arrays this reshape raises; the
                        # intended reduction should be confirmed.
                        mask = mask.reshape(mask.shape[-2], mask.shape[-1])

                if mask.dtype != bool:
                    # Values up to 1 are treated as binary; larger ranges
                    # are thresholded at half of the maximum.
                    mask = mask.astype(bool) if mask.max() <= 1 else (mask > mask.max() / 2)

            ax.imshow(mask, alpha=transparency, cmap=colormap)

        ax.axis('off')
        display_title = title or f"Segmentation: {prompt_text}"
        ax.set_title(display_title, fontsize=12, pad=10)

        # delete=False: the file must outlive this function so the caller
        # can read it; only the name is needed here.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as output_file:
            output_path = output_file.name

        fig.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=OUTPUT_DPI)
    except Exception:
        # Do not leave an empty or partial PNG behind on failure.
        if output_path is not None:
            try:
                os.remove(output_path)
            except OSError:
                pass
        raise
    finally:
        # Figures stay registered with pyplot until closed; always release.
        plt.close(fig)

    return output_path
|
|
|
|
|
|
|
|
def create_demo_dicom_file(output_path: str = "demo_brain_mri.dcm") -> bool:
    """
    Create a demo DICOM file for testing.

    Two strategies are attempted in order:

    1. Copy pydicom's bundled "MR_small.dcm" test file to ``output_path``.
    2. If that file is unavailable or copying fails, build a synthetic
       256x256 16-bit MR dataset (random noise with a brighter disc in the
       centre) and save it to ``output_path``.

    Args:
        output_path: Path where to save the demo file

    Returns:
        True if successful, False otherwise
    """
    # Strategy 1: reuse the sample that ships with pydicom.
    try:
        import shutil
        from pydicom.data import get_testdata_file

        bundled = get_testdata_file("MR_small.dcm")
        if bundled and os.path.exists(bundled):
            shutil.copy(bundled, output_path)
            logger.info(f"Demo file ready: {output_path}")
            return True
    except Exception as e:
        logger.debug(f"Could not copy test DICOM file: {e}")

    # Strategy 2: synthesise a minimal dataset from scratch.
    try:
        from pydicom.uid import generate_uid
        from pydicom.dataset import FileDataset, FileMetaDataset

        side = 256
        pixels = np.random.randint(0, 255, (side, side), dtype=np.uint16)

        # Brighten a disc of radius 100 centred on (128, 128), capped at 255.
        row_idx, col_idx = np.ogrid[:side, :side]
        disc = (col_idx - 128) ** 2 + (row_idx - 128) ** 2 <= 100 ** 2
        pixels[disc] = np.clip(pixels[disc] + 50, 0, 255)

        meta = FileMetaDataset()
        meta.MediaStorageSOPClassUID = '1.2.840.10008.5.1.4.1.1.4'
        meta.MediaStorageSOPInstanceUID = generate_uid()
        meta.TransferSyntaxUID = '1.2.840.10008.1.2.1'

        dataset = FileDataset(
            output_path, {}, file_meta=meta, preamble=b"\x00" * 128
        )

        # Header elements, assigned in the listed order.
        header_elements = (
            ('PatientName', "Demo^Patient"),
            ('PatientID', "DEMO001"),
            ('Modality', "MR"),
            ('Rows', 256),
            ('Columns', 256),
            ('BitsAllocated', 16),
            ('BitsStored', 16),
            ('HighBit', 15),
            ('SamplesPerPixel', 1),
            ('PixelRepresentation', 0),
            ('PhotometricInterpretation', "MONOCHROME2"),
            ('PixelSpacing', [1.0, 1.0]),
            ('RescaleIntercept', "0"),
            ('RescaleSlope', "1"),
        )
        for keyword, value in header_elements:
            setattr(dataset, keyword, value)

        # Pixel data goes in last, after the image description elements.
        dataset.PixelData = pixels.tobytes()

        dataset.save_as(output_path, write_like_original=False)
        logger.info(f"Synthetic demo file created: {output_path}")
        return True
    except Exception as e:
        logger.warning(f"Could not create demo file: {e}")
        return False
|
|
|
|
|
|