| | from typing import Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from transformers.image_processing_utils import BaseImageProcessor |
| |
|
| |
|
class TAPCTProcessor(BaseImageProcessor):
    """
    Image processor for TAP-CT 3D volumes.

    Processes CT volumes with the following pipeline:

    1. Spatial Resizing: Resize to (z, H', W') where H', W' are resize_dims
    2. Axial Padding: Pad z-axis with -1024 HU for divisibility by patch size
    3. Intensity Clipping: Clip to HU range
    4. Normalization: Z-score normalization

    Because padding happens before clipping, padded slices end up at
    ``clip_range[0]`` (-1008 HU with the defaults) before normalization.

    Parameters
    ----------
    resize_dims : tuple[int, int], default=(224, 224)
        Target spatial dimensions (H, W) for resizing.
    divisible_pad_z : int, default=4
        Pad the z-axis to be divisible by this value. ``None`` or any value
        ``<= 1`` disables axial padding.
    clip_range : tuple[float, float], default=(-1008.0, 822.0)
        HU intensity clipping range (min, max); min must not exceed max.
    norm_mean : float, default=-86.80862426757812
        Mean for z-score normalization.
    norm_std : float, default=322.63470458984375
        Standard deviation for z-score normalization; must be non-zero.
    **kwargs
        Additional arguments passed to BaseImageProcessor.

    Raises
    ------
    ValueError
        If ``clip_range`` is reversed (min > max) or ``norm_std`` is zero.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        resize_dims: tuple[int, int] = (224, 224),
        divisible_pad_z: int = 4,
        clip_range: tuple[float, float] = (-1008.0, 822.0),
        norm_mean: float = -86.80862426757812,
        norm_std: float = 322.63470458984375,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Fail fast on configurations that would otherwise silently clamp
        # every voxel to a single value (reversed range) or produce inf/NaN
        # (division by zero) at preprocessing time.
        if clip_range[0] > clip_range[1]:
            raise ValueError(
                f"clip_range must be (min, max) with min <= max, got {clip_range}"
            )
        if norm_std == 0:
            raise ValueError("norm_std must be non-zero")
        self.resize_dims = resize_dims
        self.divisible_pad_z = divisible_pad_z
        self.clip_range = clip_range
        self.norm_mean = norm_mean
        self.norm_std = norm_std

    def preprocess(
        self,
        images: Union[torch.Tensor, np.ndarray],
        return_tensors: str = "pt",
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """
        Preprocess CT volumes.

        Parameters
        ----------
        images : torch.Tensor or np.ndarray
            Input tensor or numpy array of shape (B, C, D, H, W) where
            B=batch, C=channels, D=depth/slices, H=height, W=width.
            NumPy arrays of any numeric dtype and any memory layout
            (including negative strides, e.g. from ``np.flip``) are accepted.
        return_tensors : str, default="pt"
            Return format. Only "pt" (PyTorch) is supported.
        **kwargs
            Additional keyword arguments (unused).

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary with "pixel_values" containing a float32 tensor of
            shape (B, C, D', H', W') where D' may be padded for divisibility.

        Raises
        ------
        ValueError
            If return_tensors is not "pt" or input is not 5D.
        """
        if return_tensors != "pt":
            raise ValueError(f"Only 'pt' return_tensors is supported, got {return_tensors}")

        images = self._to_float_tensor(images)

        if images.ndim != 5:
            raise ValueError(f"Expected 5D input (B, C, D, H, W), got shape {images.shape}")

        height, width = images.shape[3], images.shape[4]

        # 1. Spatial resize (skipped when the volume is already at target size).
        target_h, target_w = self.resize_dims
        if height != target_h or width != target_w:
            images = self._resize_spatial(images, target_h, target_w)

        # 2. Axial padding with air (-1024 HU).
        images = self._pad_axial(images)

        # 3. Intensity clipping to the configured HU window.
        images = torch.clamp(images, min=self.clip_range[0], max=self.clip_range[1])

        # 4. Z-score normalization with fixed dataset statistics.
        images = (images - self.norm_mean) / self.norm_std

        return {"pixel_values": images}

    @staticmethod
    def _to_float_tensor(images: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """
        Convert the input volume to a float32 ``torch.Tensor``.

        Tensors are cast with ``.float()`` (keeping their device). Anything
        else is cast to a C-contiguous float32 NumPy array first. This copies
        arrays with negative strides, which ``torch.from_numpy`` rejects. It
        also handles dtypes that ``torch.from_numpy`` does not accept in some
        torch versions (e.g. ``uint16``).
        """
        if isinstance(images, torch.Tensor):
            return images.float()
        return torch.from_numpy(np.ascontiguousarray(images, dtype=np.float32))

    def _resize_spatial(
        self,
        images: torch.Tensor,
        target_h: int,
        target_w: int,
    ) -> torch.Tensor:
        """
        Resize spatial dimensions (H, W) using trilinear interpolation.

        The depth target equals the input depth. With ``align_corners=False``
        the depth samples then land exactly on the original slices, so only H
        and W are interpolated.

        Parameters
        ----------
        images : torch.Tensor
            Tensor of shape (B, C, D, H, W).
        target_h : int
            Target height.
        target_w : int
            Target width.

        Returns
        -------
        torch.Tensor
            Resized tensor of shape (B, C, D, target_h, target_w).
        """
        depth = images.shape[2]
        return F.interpolate(
            images,
            size=(depth, target_h, target_w),
            mode="trilinear",
            align_corners=False,
        )

    def _pad_axial(self, images: torch.Tensor) -> torch.Tensor:
        """
        Pad the axial (z/depth) dimension with -1024 HU for divisibility.

        Padding is appended after the last slice. When ``divisible_pad_z`` is
        ``None`` or ``<= 1`` the input is returned unchanged. Such values
        cannot require padding, and 0 would otherwise raise
        ``ZeroDivisionError``.

        Parameters
        ----------
        images : torch.Tensor
            Tensor of shape (B, C, D, H, W).

        Returns
        -------
        torch.Tensor
            Padded tensor of shape (B, C, D', H, W) where D' is divisible
            by divisible_pad_z.
        """
        multiple = self.divisible_pad_z
        if multiple is None or multiple <= 1:
            return images

        remainder = images.shape[2] % multiple
        if remainder == 0:
            return images

        pad_z = multiple - remainder
        # F.pad lists pads from the last dim backwards:
        # (W_left, W_right, H_top, H_bottom, D_front, D_back).
        padding = (0, 0, 0, 0, 0, pad_z)
        return F.pad(images, padding, mode="constant", value=-1024.0)
| |
|