CT-FM SegResNet

This model is a SegResNet initialized with the weights of the pre-trained CT-FM, which was trained with contrastive self-supervised learning on 148,000 CT scans from the Imaging Data Commons.

Running instructions

CT-FM SegResNet Fine-tuning

This notebook demonstrates how to:

  1. Load an SSL pre-trained model into a SegResNet
  2. Apply the recommended preprocessing and postprocessing steps that were used during pre-training
  3. Fine-tune the model (overview of instructions)

Setup

Install requirements and import necessary packages

# Install lighter_zoo package
%pip install lighter_zoo -U -qq
Note: you may need to restart the kernel to use updated packages.

# Imports
import torch
from lighter_zoo import SegResNet
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground, Invert,
    Activations, AsDiscrete, KeepLargestConnectedComponent,
    SaveImage
)
from monai.inferers import SlidingWindowInferer

Load Model

Download and initialize the pre-trained model from HuggingFace Hub

# Load pre-trained model
model = SegResNet.from_pretrained(
    "project-lighter/ct_fm_segresnet"
)

Setup Processing Pipelines

Define preprocessing and postprocessing transforms

# Preprocessing pipeline
preprocess = Compose([
    LoadImage(ensure_channel_first=True),  # Load image and ensure channel dimension
    EnsureType(),                         # Ensure correct data type
    Orientation(axcodes="SPL"),           # Standardize orientation
    # Scale intensity to [0,1] range, clipping outliers
    ScaleIntensityRange(
        a_min=-1024,    # Min HU value
        a_max=2048,     # Max HU value
        b_min=0,        # Target min
        b_max=1,        # Target max
        clip=True       # Clip values outside range
    ),
    CropForeground()    # Remove background to reduce computation
])
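
The intensity scaling step above is a linear map followed by clipping. As a minimal sketch of that arithmetic for a single voxel value (the helper name `scale_hu` is hypothetical and not part of MONAI; the default constants mirror the `ScaleIntensityRange` arguments above):

```python
def scale_hu(value, a_min=-1024.0, a_max=2048.0, b_min=0.0, b_max=1.0, clip=True):
    """Linearly map [a_min, a_max] to [b_min, b_max], optionally clipping."""
    scaled = (value - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    if clip:
        # Values outside the source range saturate at the target bounds
        scaled = min(max(scaled, b_min), b_max)
    return scaled


# Illustrative Hounsfield unit values (chosen for demonstration only)
for hu in (-2000, -1024, 0, 512, 2048, 3000):
    print(hu, scale_hu(hu))
```

With these settings, anything at or below -1024 HU maps to 0 and anything at or above 2048 HU maps to 1.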

Run Inference

Process an input CT scan and extract features

# Configure sliding window inference
inferer = SlidingWindowInferer(
    roi_size=[96, 160, 160],  # Size of patches to process
    sw_batch_size=2,          # Number of windows to process in parallel
    overlap=0.625,            # Overlap between windows (reduces boundary artifacts)
    mode="gaussian"           # Gaussian weighting for overlap regions
)
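
To get a feel for the cost of these settings, the sketch below estimates how many windows are evaluated for a given volume. It approximates how a sliding-window inferer places windows (step of `roi_size * (1 - overlap)` voxels per axis) and may differ from MONAI's implementation in edge cases. The helper name `count_windows` is hypothetical, and the spatial shape is taken from the example output printed later in this notebook, assuming the output keeps the spatial size of the preprocessed input.

```python
import math


def count_windows(image_size, roi_size, overlap):
    """Estimate the number of sliding-window positions along each spatial axis."""
    counts = []
    for image_dim, roi_dim in zip(image_size, roi_size):
        # Volumes smaller than the window are treated as padded up to the window size
        image_dim = max(image_dim, roi_dim)
        # Step between consecutive window origins (at least one voxel)
        interval = max(int(roi_dim * (1 - overlap)), 1)
        counts.append(math.ceil((image_dim - roi_dim) / interval) + 1)
    return counts


counts = count_windows(image_size=(227, 181, 258), roi_size=(96, 160, 160), overlap=0.625)
total = math.prod(counts)
sw_batch_size = 2
# Windows per axis, total windows, and forward passes at this batch size
print(counts, total, math.ceil(total / sw_batch_size))
```

Lowering `overlap` reduces the window count (and runtime) at the cost of less blending between neighboring patches.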

# Input path (replace with the path to your own CT scan)
input_path = "/home/suraj/Repositories/semantic-search-app/assets/scans/s0114.nii.gz"

# Preprocess input
input_tensor = preprocess(input_path)

# Run inference
with torch.no_grad():
    model = model.to("cuda")
    input_tensor = input_tensor.to("cuda")
    output = inferer(input_tensor.unsqueeze(dim=0), model)[0]
    output = output.to("cpu")


print(output.shape)
torch.Size([2, 227, 181, 258])
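
The output has two channels, that is, one score per class for every voxel. Once the model has been fine-tuned, such scores are typically converted to a label map with a softmax followed by an argmax over the channel axis, which is what MONAI's `Activations(softmax=True)` and `AsDiscrete(argmax=True)` compute. A minimal pure-Python sketch for a single voxel (the function names, the example scores, and the background/foreground channel order are assumptions for illustration):

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of per-class scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(scores):
    """Return the index of the most probable class."""
    probs = softmax(scores)
    return max(range(len(probs)), key=probs.__getitem__)


# Hypothetical scores for one voxel, assumed order: [background, foreground]
voxel_scores = [0.3, 1.7]
print(softmax(voxel_scores), predict_label(voxel_scores))
```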

Fine-tuning Instructions

The model above does not include a trained decoder, which means the predictions you receive will be nonsensical.

However, you can use the pre-trained encoder and model architecture to fine-tune on your own datasets, which is especially useful when they are small. A simple way to do this is to replace the model in your training pipeline with the pre-trained version. For example:

    model = SegResNet.from_pretrained('project-lighter/ct_fm_segresnet')
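
When fine-tuning on a small dataset, one common option (not prescribed by this model card) is to update the pre-trained encoder with a smaller learning rate than the untrained layers. The sketch below splits `(name, parameter)` pairs into two groups in the list-of-dicts format accepted by PyTorch optimizers. The `encoder.` prefix, the helper name `build_param_groups`, the demo parameter names, and the learning rates are all assumptions; inspect `model.named_parameters()` to find the actual names.

```python
def build_param_groups(named_params, encoder_prefix="encoder.", base_lr=1e-3, encoder_lr_scale=0.1):
    """Split parameters into encoder and non-encoder groups with separate learning rates."""
    encoder, rest = [], []
    for name, param in named_params:
        (encoder if name.startswith(encoder_prefix) else rest).append(param)
    return [
        {"params": encoder, "lr": base_lr * encoder_lr_scale},
        {"params": rest, "lr": base_lr},
    ]


# Hypothetical names; placeholder strings stand in for tensors.
# With a real model, pass model.named_parameters() instead.
demo = [("encoder.conv.weight", "w0"), ("up_layers.0.weight", "w1"), ("head.bias", "b0")]
groups = build_param_groups(demo)
print([len(g["params"]) for g in groups])
```

The returned list can be passed as the first argument to an optimizer such as `torch.optim.AdamW`, provided both groups end up non-empty.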

We recommend using Auto3DSeg in conjunction with our model. For detailed guidance, please refer to the instructions here:
https://project-lighter.github.io/CT-FM/replication-guide/downstream/#tumor-segmentation-with-auto3dseg

Model size: 87.2M params (Safetensors, tensor type F32)