| from typing import Union |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from transformers.image_processing_utils import BaseImageProcessor |
|
|
|
|
class TAPCTProcessor(BaseImageProcessor):
    """
    Image processor for TAP-CT 3D volumes.

    Processes CT volumes with the following pipeline:

    1. Spatial Resizing: Resize to (z, H', W') where H', W' are resize_dims
    2. Axial Padding: Pad the z-axis with -1024 HU so its length is divisible
       by divisible_pad_z
    3. Intensity Clipping: Clip to HU range
    4. Normalization: Z-score normalization

    Parameters
    ----------
    resize_dims : tuple[int, int], default=(224, 224)
        Target spatial dimensions (H, W) for resizing.
    divisible_pad_z : int, default=4
        Pad the z-axis to be divisible by this value. Must be >= 1; a value
        of 1 disables padding.
    clip_range : tuple[float, float], default=(-1008.0, 822.0)
        HU intensity clipping range (min, max). min must not exceed max.
    norm_mean : float, default=-86.80862426757812
        Mean for z-score normalization.
    norm_std : float, default=322.63470458984375
        Standard deviation for z-score normalization. Must be non-zero.
    **kwargs
        Additional arguments passed to BaseImageProcessor.

    Raises
    ------
    ValueError
        If divisible_pad_z < 1, norm_std == 0, or clip_range[0] > clip_range[1].
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        resize_dims: tuple[int, int] = (224, 224),
        divisible_pad_z: int = 4,
        clip_range: tuple[float, float] = (-1008.0, 822.0),
        norm_mean: float = -86.80862426757812,
        norm_std: float = 322.63470458984375,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # Fail fast on settings that would otherwise surface later as a
        # ZeroDivisionError, silent slice cropping (negative F.pad amounts),
        # inf/NaN outputs, or a constant clipped volume.
        if divisible_pad_z < 1:
            raise ValueError(f"divisible_pad_z must be >= 1, got {divisible_pad_z}")
        if norm_std == 0:
            raise ValueError("norm_std must be non-zero")
        if clip_range[0] > clip_range[1]:
            raise ValueError(f"clip_range must be (min, max) with min <= max, got {clip_range}")

        self.resize_dims = resize_dims
        self.divisible_pad_z = divisible_pad_z
        self.clip_range = clip_range
        self.norm_mean = norm_mean
        self.norm_std = norm_std

    def preprocess(
        self,
        images: Union[torch.Tensor, np.ndarray],
        return_tensors: str = "pt",
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """
        Preprocess CT volumes.

        Parameters
        ----------
        images : torch.Tensor or np.ndarray
            Input tensor or array of shape (B, C, D, H, W) where
            B=batch, C=channels, D=depth/slices, H=height, W=width.
            Any other array-like accepted by ``np.ascontiguousarray``
            (e.g. nested lists) is also converted.
        return_tensors : str, default="pt"
            Return format. Only "pt" (PyTorch) is supported.
        **kwargs
            Additional keyword arguments (unused).

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary with "pixel_values" containing a float32 tensor of shape
            (B, C, D', H', W') where D' may be padded for divisibility.

        Raises
        ------
        ValueError
            If return_tensors is not "pt" or input is not 5D.
        """
        if return_tensors != "pt":
            raise ValueError(f"Only 'pt' return_tensors is supported, got {return_tensors}")

        if isinstance(images, torch.Tensor):
            images = images.float()
        else:
            # Convert in NumPy first so torch.from_numpy always receives a
            # C-contiguous, native-endian float32 array. Wrapping the raw array
            # fails for negatively-strided views (e.g. np.flip), non-native
            # byte order, or dtypes unsupported by some torch releases.
            images = torch.from_numpy(np.ascontiguousarray(images, dtype=np.float32))

        if images.ndim != 5:
            raise ValueError(f"Expected 5D input (B, C, D, H, W), got shape {images.shape}")

        target_h, target_w = self.resize_dims
        height, width = images.shape[-2:]
        if (height, width) != (target_h, target_w):
            images = self._resize_spatial(images, target_h, target_w)

        images = self._pad_axial(images)

        # Clipping runs after padding, so padded slices (-1024) are raised to
        # clip_range[0] whenever that bound is above -1024 (as with the default).
        images = torch.clamp(images, min=self.clip_range[0], max=self.clip_range[1])

        images = (images - self.norm_mean) / self.norm_std

        return {"pixel_values": images}

    def _resize_spatial(
        self,
        images: torch.Tensor,
        target_h: int,
        target_w: int,
    ) -> torch.Tensor:
        """
        Resize spatial dimensions (H, W) using trilinear interpolation.

        The depth size is passed through unchanged. With
        ``align_corners=False`` and a unit depth scale, the depth axis is
        sampled at its original positions, so this is effectively in-plane
        resizing of each slice.

        Parameters
        ----------
        images : torch.Tensor
            Tensor of shape (B, C, D, H, W).
        target_h : int
            Target height.
        target_w : int
            Target width.

        Returns
        -------
        torch.Tensor
            Resized tensor of shape (B, C, D, target_h, target_w).
        """
        depth = images.shape[2]
        return F.interpolate(
            images,
            size=(depth, target_h, target_w),
            mode="trilinear",
            align_corners=False,
        )

    def _pad_axial(self, images: torch.Tensor) -> torch.Tensor:
        """
        Pad the end of the axial (z/depth) dimension with -1024 HU for divisibility.

        Parameters
        ----------
        images : torch.Tensor
            Tensor of shape (B, C, D, H, W).

        Returns
        -------
        torch.Tensor
            Padded tensor of shape (B, C, D', H, W) where D' is the smallest
            multiple of divisible_pad_z that is >= D. Returned unchanged when
            D is already divisible.
        """
        depth = images.shape[2]
        remainder = depth % self.divisible_pad_z
        if remainder == 0:
            return images

        pad_z = self.divisible_pad_z - remainder

        # F.pad lists amounts from the last dimension backwards:
        # (W_left, W_right, H_top, H_bottom, D_front, D_back). Only the back
        # of the depth axis is padded.
        padding = (0, 0, 0, 0, 0, pad_z)
        return F.pad(images, padding, mode="constant", value=-1024.0)
|
|