# TILA / processor.py
# Uploaded via huggingface_hub by lukeingawesome (commit f46fb4d, verified)
"""
TILA β€” Image Processor
Single processor that handles the full pipeline:
raw image (path, numpy, or PIL) β†’ model-ready tensor [1, 3, 448, 448]
Combines:
1. Medical image preprocessing (windowing, padding removal, resize)
2. Model transforms (resize, center crop, to tensor, expand channels)
Usage:
from processor import TILAProcessor
processor = TILAProcessor()
# From file path (applies full preprocessing)
tensor = processor("raw_cxr.png")
# From PIL image (skips medical preprocessing, applies model transforms only)
tensor = processor(Image.open("preprocessed.png"))
# Pair of images for the model
current = processor("current.png")
previous = processor("previous.png")
result = model.get_interval_change_prediction(current, previous)
"""
from pathlib import Path
from typing import Union

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from preprocess import preprocess_image
class TILAProcessor:
    """End-to-end image processor for the TILA model.

    Accepts file paths (str/Path), numpy arrays, or PIL Images.
    - File paths: full pipeline (windowing -> crop -> resize -> model transform)
    - Numpy arrays: full raw pipeline when ``raw_preprocess`` is True; otherwise
      the array is treated as already-preprocessed pixels
    - PIL Images: assumed already preprocessed, only model transforms applied

    Args:
        raw_preprocess: Apply medical preprocessing (windowing, padding removal).
            Set False if images are already preprocessed PNGs.
        width_param: Windowing width parameter (default: 4.0)
        max_size: Resize longest side to this before model transforms (default: 512)
        crop_size: Center crop size for model input (default: 448)
        dtype: Output tensor dtype (default: torch.bfloat16)
        device: Output tensor device (default: "cpu")
    """

    def __init__(
        self,
        raw_preprocess: bool = True,
        width_param: float = 4.0,
        max_size: int = 512,
        crop_size: int = 448,
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cpu",
    ):
        self.raw_preprocess = raw_preprocess
        self.width_param = width_param
        self.max_size = max_size
        self.crop_size = crop_size  # stored for introspection (was previously dropped)
        self.dtype = dtype
        self.device = device
        # Model-input transform: shorter-side resize, center crop to crop_size,
        # scale to [0, 1] float tensor, then expand grayscale to 3 channels.
        self.model_transform = transforms.Compose([
            transforms.Resize(max_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            _ExpandChannels(),
        ])

    def __call__(self, image: Union[str, Path, np.ndarray, Image.Image]) -> torch.Tensor:
        """Process a single image into a model-ready tensor.

        Args:
            image: File path (str or pathlib.Path), numpy array, or PIL Image

        Returns:
            Tensor of shape [1, 3, crop_size, crop_size] (default [1, 3, 448, 448])

        Raises:
            TypeError: If ``image`` is not one of the supported types.
        """
        if isinstance(image, (str, Path)):
            # BUG FIX: Path objects were rejected before, despite the class
            # docstring advertising str/Path support.
            if self.raw_preprocess:
                img_np = preprocess_image(str(image), self.width_param, self.max_size)
                pil_img = Image.fromarray(img_np)
            else:
                pil_img = Image.open(image).convert("L")
        elif isinstance(image, np.ndarray):
            if self.raw_preprocess:
                from preprocess import apply_windowing, remove_black_padding, resize_preserve_aspect_ratio
                if len(image.shape) == 3:
                    # NOTE(review): assumes OpenCV-style BGR channel order — confirm
                    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
                image = apply_windowing(image, self.width_param)
                # presumably apply_windowing outputs values in [0, 1] — TODO confirm
                image = (image * 255.0).astype(np.uint8)
                image = remove_black_padding(image)
                image = resize_preserve_aspect_ratio(image, self.max_size)
                pil_img = Image.fromarray(image)
            else:
                # BUG FIX: this path previously left pil_img unassigned,
                # raising NameError below. Treat the array as preprocessed
                # pixels and convert to grayscale, mirroring the str path.
                pil_img = Image.fromarray(image).convert("L")
        elif isinstance(image, Image.Image):
            pil_img = image.convert("L")
        else:
            raise TypeError(f"Expected str, np.ndarray, or PIL.Image, got {type(image)}")
        # Add batch dim, then move/cast once at the end.
        tensor = self.model_transform(pil_img).unsqueeze(0)
        return tensor.to(dtype=self.dtype, device=self.device)
class _ExpandChannels:
"""Expand single-channel tensor to 3 channels."""
def __call__(self, x: torch.Tensor) -> torch.Tensor:
if x.shape[0] == 1:
return x.repeat(3, 1, 1)
return x