| """Image processor for Sybil CT scan preprocessing""" |
|
|
| import numpy as np |
| import torch |
| from typing import Dict, List, Optional, Union, Tuple |
| from transformers.image_processing_utils import BaseImageProcessor, BatchFeature |
| from transformers.utils import TensorType |
| import pydicom |
| from PIL import Image |
| import torchio as tio |
|
|
|
|
def order_slices(dicoms: List) -> List:
    """Order DICOM slices by their position.

    Sorting is attempted with two keys, in order of preference:

    1. the third component of ``ImagePositionPatient``;
    2. ``InstanceNumber``.

    If neither key can be evaluated for every slice, the input order is
    kept (best effort).

    Args:
        dicoms: Objects exposing DICOM attributes (e.g. pydicom datasets).

    Returns:
        A new sorted list, or ``dicoms`` itself when no key is usable.
    """
    sort_keys = (
        lambda dcm: float(dcm.ImagePositionPatient[2]),
        lambda dcm: int(dcm.InstanceNumber),
    )
    for sort_key in sort_keys:
        try:
            return sorted(dicoms, key=sort_key)
        except (AttributeError, TypeError, ValueError, IndexError):
            # AttributeError: tag missing; TypeError: tag is None/unsortable;
            # ValueError: non-numeric value; IndexError: position vector with
            # fewer than three components. Try the next key instead of failing.
            continue
    return dicoms
|
|
|
|
class SybilImageProcessor(BaseImageProcessor):
    """
    Constructs a Sybil image processor for preprocessing CT scans.

    The pipeline implemented by `preprocess` is: load the volume, apply the
    intensity window, optionally resample to `voxel_spacing`, pad/crop to
    `(default_depth, *default_size)` and replicate the volume on 3 channels.

    Args:
        voxel_spacing (`List[float]`, *optional*, defaults to `[0.703125, 0.703125, 2.5]`):
            Target voxel spacing for resampling (row, column, slice thickness).
        img_size (`List[int]`, *optional*, defaults to `[512, 512]`):
            Target image size after resizing.
        num_images (`int`, *optional*, defaults to `208`):
            Number of slices to use from the CT scan.
        windowing (`Dict[str, float]`, *optional*):
            Windowing parameters for CT scan visualization.
            Default uses lung window: center=-600, width=1500.
        normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize pixel values to [0, 1].
        **kwargs:
            Additional keyword arguments passed to the parent class.

    NOTE(review): `img_size` and `num_images` are stored on the instance but
    are not read anywhere in this class; the output shape is governed by
    `default_depth` / `default_size`. Confirm whether that is intended.
    """

    model_input_names = ["pixel_values"]

    # Instance attributes that hold torchio transform objects. They cannot be
    # JSON-serialized, so they are excluded from `to_dict`.
    _TRANSFORM_ATTRIBUTES = ("resample_transform", "padding_transform")

    def __init__(
        self,
        voxel_spacing: Optional[List[float]] = None,
        img_size: Optional[List[int]] = None,
        num_images: int = 208,
        windowing: Optional[Dict[str, float]] = None,
        normalize: bool = True,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.voxel_spacing = voxel_spacing if voxel_spacing is not None else [0.703125, 0.703125, 2.5]
        self.img_size = img_size if img_size is not None else [512, 512]
        self.num_images = num_images

        # Default intensity window: center=-600, width=1500.
        self.windowing = windowing if windowing is not None else {
            "center": -600,
            "width": 1500,
        }
        self.normalize = normalize

        # Resampling to the target voxel spacing.
        self.resample_transform = tio.transforms.Resample(target=self.voxel_spacing)

        # Fixed output shape (depth, height, width) enforced by pad/crop.
        self.default_depth = 200
        self.default_size = [256, 256]
        self.padding_transform = tio.transforms.CropOrPad(
            target_shape=(self.default_depth, *self.default_size),
            padding_mode=0,
        )

    def to_dict(self) -> Dict:
        """
        Serialize the processor configuration.

        The torchio transform objects are dropped from the result: they are
        rebuilt by `__init__` and would otherwise make JSON serialization
        (used by `repr` and `save_pretrained`) fail.
        """
        output = super().to_dict()
        for name in self._TRANSFORM_ATTRIBUTES:
            output.pop(name, None)
        return output

    @staticmethod
    def _slice_to_array(dcm) -> np.ndarray:
        """Return one slice as float32, applying its own rescale slope/intercept if present."""
        pixels = dcm.pixel_array.astype(np.float32)
        if hasattr(dcm, "RescaleSlope") and hasattr(dcm, "RescaleIntercept"):
            pixels = pixels * float(dcm.RescaleSlope) + float(dcm.RescaleIntercept)
        return pixels

    @staticmethod
    def _to_batch_feature(
        pixel_values: torch.Tensor,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Wrap `pixel_values` in a BatchFeature, honoring `return_tensors` ("pt", "np" or None)."""
        if return_tensors == "pt":
            return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=TensorType.PYTORCH)
        if return_tensors == "np":
            return BatchFeature(data={"pixel_values": pixel_values.numpy()}, tensor_type=TensorType.NUMPY)
        return BatchFeature(data={"pixel_values": pixel_values})

    def load_dicom_series(self, paths: List[str]) -> Tuple[np.ndarray, Dict]:
        """
        Load a series of DICOM files.

        Files that cannot be read are reported and skipped. The remaining
        slices are ordered with `order_slices` and stacked along axis 0.

        Args:
            paths: List of paths to DICOM files.

        Returns:
            Tuple of (volume array, metadata dict). The metadata holds
            `slice_thickness`, `pixel_spacing`, `manufacturer` (each `None`
            when the first slice lacks the tag) and `num_slices`.

        Raises:
            ValueError: If no file could be read.
        """
        dicoms = []
        for path in paths:
            try:
                dicoms.append(pydicom.dcmread(path, stop_before_pixels=False))
            except Exception as e:
                # Best effort: one unreadable file must not abort the series.
                print(f"Error reading DICOM file {path}: {e}")

        if not dicoms:
            raise ValueError("No valid DICOM files found")

        dicoms = order_slices(dicoms)

        # The rescale is applied per slice because slope/intercept are
        # attributes of each image, not of the series.
        volume = np.stack([self._slice_to_array(dcm) for dcm in dicoms])

        first = dicoms[0]
        metadata = {
            "slice_thickness": float(first.SliceThickness) if hasattr(first, 'SliceThickness') else None,
            "pixel_spacing": list(map(float, first.PixelSpacing)) if hasattr(first, 'PixelSpacing') else None,
            "manufacturer": str(first.Manufacturer) if hasattr(first, 'Manufacturer') else None,
            "num_slices": len(dicoms),
        }

        return volume, metadata

    def load_png_series(self, paths: List[str]) -> np.ndarray:
        """
        Load a series of PNG files.

        Args:
            paths: List of paths to PNG files (must be in anatomical order).

        Returns:
            3D volume array (one grayscale float32 slice per path).
        """
        slices = []
        for path in paths:
            # The context manager releases the file handle deterministically.
            with Image.open(path) as img:
                slices.append(np.array(img.convert('L'), dtype=np.float32))

        return np.stack(slices)

    def apply_windowing(self, volume: np.ndarray) -> np.ndarray:
        """
        Apply windowing to CT scan for better visualization.

        Values are clipped to `[center - width / 2, center + width / 2]` and,
        when `self.normalize` is set, rescaled linearly to `[0, 1]`.

        Args:
            volume: 3D CT volume.

        Returns:
            Windowed volume.

        Raises:
            ValueError: If the configured window width is not positive.
        """
        center = self.windowing["center"]
        width = self.windowing["width"]
        if width <= 0:
            # A zero width would turn the normalization into a division by zero.
            raise ValueError(f"Window width must be positive, got {width}")

        lower = center - width / 2
        upper = center + width / 2

        windowed = np.clip(volume, lower, upper)
        if self.normalize:
            windowed = (windowed - lower) / (upper - lower)

        return windowed

    def resample_volume(
        self,
        volume: torch.Tensor,
        original_spacing: Optional[List[float]] = None
    ) -> torch.Tensor:
        """
        Resample volume to target voxel spacing.

        Args:
            volume: 3D volume tensor laid out as (slice, row, column).
            original_spacing: Original voxel spacing as
                (row, column, slice thickness). When `None`, the volume is
                treated as having unit spacing.

        Returns:
            Resampled volume, in the same (slice, row, column) layout.

        Raises:
            ValueError: If `original_spacing` does not hold 3 positive values.
        """
        # Move the slice axis last so the spatial axes line up with the
        # (row, column, slice) order used by both spacings; torchio expects a
        # leading channel axis.
        image_kwargs = {"tensor": volume.permute(1, 2, 0).unsqueeze(0)}

        if original_spacing is not None:
            spacing = [float(value) for value in original_spacing]
            if len(spacing) != 3 or any(value <= 0 for value in spacing):
                raise ValueError(
                    f"original_spacing must contain 3 positive values, got {original_spacing}"
                )
            # torchio derives voxel spacing from the affine; a `spacing=`
            # keyword would only be stored as metadata and ignored.
            image_kwargs["affine"] = np.diag(spacing + [1.0])

        subject = tio.Subject(image=tio.ScalarImage(**image_kwargs))
        resampled = self.resample_transform(subject)

        # Drop the channel axis and restore the (slice, row, column) layout.
        return resampled['image'].data.squeeze(0).permute(2, 0, 1).contiguous()

    def pad_or_crop_volume(self, volume: torch.Tensor) -> torch.Tensor:
        """
        Pad or crop volume to target shape.

        Args:
            volume: 3D volume tensor.

        Returns:
            Padded/cropped volume of shape `(default_depth, *default_size)`.
        """
        subject = tio.Subject(
            image=tio.ScalarImage(tensor=volume.unsqueeze(0))
        )
        transformed = self.padding_transform(subject)

        return transformed['image'].data.squeeze(0)

    def preprocess(
        self,
        images: Union[List[str], np.ndarray, torch.Tensor],
        file_type: str = "dicom",
        voxel_spacing: Optional[List[float]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs
    ) -> BatchFeature:
        """
        Preprocess CT scan images.

        Args:
            images: Either list of file paths or numpy/torch array of images.
            file_type: Type of input files ("dicom" or "png").
            voxel_spacing: Original voxel spacing (required for PNG files).
                For DICOM input it is read from the first slice when omitted;
                if it cannot be determined, resampling is skipped.
            return_tensors: The type of tensors to return.

        Returns:
            BatchFeature with preprocessed images.

        Raises:
            ValueError: On an empty path list, an unknown `file_type`, a PNG
                series without `voxel_spacing`, or an unsupported input type.
        """
        if isinstance(images, list):
            if not images:
                raise ValueError("No images provided")
            if not isinstance(images[0], str):
                raise ValueError("Images must be file paths, numpy array, or torch tensor")

            if file_type == "dicom":
                volume, metadata = self.load_dicom_series(images)
                # Use the spacing from the file only when it is complete; a
                # missing slice thickness would yield an invalid spacing.
                if voxel_spacing is None and metadata["pixel_spacing"] and metadata["slice_thickness"]:
                    voxel_spacing = metadata["pixel_spacing"] + [metadata["slice_thickness"]]
            elif file_type == "png":
                if voxel_spacing is None:
                    raise ValueError("voxel_spacing must be provided for PNG files")
                volume = self.load_png_series(images)
            else:
                raise ValueError(f"Unknown file type: {file_type}")
        elif isinstance(images, torch.Tensor):
            # detach/cpu so tensors tracked by autograd or on an accelerator work.
            volume = images.detach().cpu().float().numpy()
        elif isinstance(images, np.ndarray):
            volume = np.asarray(images, dtype=np.float32)
        else:
            raise ValueError("Images must be file paths, numpy array, or torch tensor")

        # Windowing runs in numpy; everything after it uses torch tensors.
        volume = torch.from_numpy(self.apply_windowing(volume)).float()

        if voxel_spacing is not None:
            volume = self.resample_volume(volume, voxel_spacing)

        volume = self.pad_or_crop_volume(volume)

        # Replicate the single-channel volume on 3 channels: (3, D, H, W).
        volume = volume.unsqueeze(0).repeat(3, 1, 1, 1)

        return self._to_batch_feature(volume, return_tensors)

    def __call__(
        self,
        images: Union[List[str], List[List[str]], np.ndarray, torch.Tensor],
        **kwargs
    ) -> BatchFeature:
        """
        Main method to prepare images for the model.

        Args:
            images: Images to preprocess. Can be:
                - List of file paths for a single series
                - List of lists of file paths for multiple series
                - Numpy array or torch tensor
            **kwargs: Forwarded to `preprocess`.

        Returns:
            BatchFeature with preprocessed images ready for model input.
            For multiple series the volumes are stacked on a new leading axis.
        """
        if isinstance(images, list) and images and isinstance(images[0], list):
            # Each series is preprocessed as torch tensors so they can be
            # stacked; the requested output type is applied once at the end.
            return_tensors = kwargs.pop("return_tensors", None)
            batch_volumes = [
                self.preprocess(series_paths, **kwargs)["pixel_values"]
                for series_paths in images
            ]
            return self._to_batch_feature(torch.stack(batch_volumes), return_tensors)

        return self.preprocess(images, **kwargs)