DS6 / app.py
soumickmj's picture
trial version ready
9454e59
raw
history blame
10.1 kB
import streamlit as st
import math
import numpy as np
import nibabel as nib
import torch
import torch.nn.functional as F
from transformers import AutoModel
import os
import tempfile
from pathlib import Path
from skimage.filters import threshold_otsu
import torchio as tio
# import psutil
def infer_full_vol(tensor, model):
    """Run the model on a whole 3D volume in a single forward pass.

    The volume gets batch and channel dimensions, its last axis is moved in
    front of the other two spatial axes, it is divided by its maximum, padded
    with zeros so that every spatial size is a multiple of 16, and passed
    through ``model``. The result is squashed with a sigmoid, cropped back to
    the unpadded extent and returned with the axes in the input order.

    Args:
        tensor: 3D ``torch.Tensor`` holding the volume.
        model: Callable that accepts a 5D tensor. If it returns a tuple or a
            list, only the first element is used.

    Returns:
        ``numpy.ndarray`` of sigmoid probabilities with all singleton
        dimensions squeezed out.
    """
    tensor = tensor.unsqueeze(0).unsqueeze(0)  # Shape: [1, 1, D, H, W] - adding batch and channel dims
    tensor = torch.movedim(tensor, -1, -3)

    # Max-normalise, but leave an all-zero volume untouched: 0 / 0 would fill
    # the input (and therefore the prediction) with NaNs.
    peak = tensor.max()
    if peak != 0:
        tensor = tensor / peak

    # Pad every spatial dim up to the next multiple of 16, split as evenly as
    # possible between both sides.
    # NOTE(review): presumably required by the network's down-/up-sampling
    # path - confirm against the model definition.
    sizes = tensor.shape[-3:]
    pad_before = []
    pad_after = []
    for size in sizes:
        total_pad = math.ceil(size / 16) * 16 - size
        pad_before.append(total_pad // 2)
        pad_after.append(total_pad - total_pad // 2)

    # F.pad expects (before, after) pairs starting from the LAST dimension.
    padding = []
    for before, after in zip(reversed(pad_before), reversed(pad_after)):
        padding.extend([before, after])
    tensor = F.pad(tensor, padding)

    with torch.no_grad():
        output = model(tensor)
        if isinstance(output, (tuple, list)):
            output = output[0]
        output = torch.sigmoid(output)

    # Crop the three trailing (spatial) dims back to the unpadded extent.
    crop = [slice(None)] * (output.dim() - 3)
    crop.extend(slice(before, before + size) for before, size in zip(pad_before, sizes))
    output = output[tuple(crop)]

    output = torch.movedim(output, -3, -1).type(tensor.type())
    return output.squeeze().detach().cpu().numpy()
def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_width=32, stride_depth=16, batch_size=10, num_worker=2):
    """Run the model patch by patch and stitch the predictions together.

    The volume is wrapped in a TorchIO subject, sampled on a regular grid of
    patches, and every batch of patches is normalised, passed through
    ``model`` and squashed with a sigmoid. Overlapping predictions are
    averaged by the TorchIO grid aggregator. Progress is reported through
    Streamlit (one text line per batch plus a progress bar).

    Args:
        tensor: 3D ``torch.Tensor`` holding the volume (a channel dim is added
            here).
        model: Callable applied to each batch of patches. If it returns a
            tuple or a list, only the first element is used.
        patch_size: Patch size handed to the grid sampler.
        stride_length, stride_width, stride_depth: Strides along the three
            volume axes; the sampler overlap is ``patch_size - stride``.
        batch_size: Number of patches per forward pass.
        num_worker: Number of ``DataLoader`` worker processes.

    Returns:
        ``numpy.ndarray`` with the aggregated sigmoid probabilities, singleton
        dimensions squeezed out.
    """
    test_subject = tio.Subject(img=tio.ScalarImage(tensor=tensor.unsqueeze(0)))  # adding channel dim while creating the TorchIO subject
    overlap = np.subtract(patch_size, (stride_length, stride_width, stride_depth))

    with torch.no_grad():
        grid_sampler = tio.inference.GridSampler(
            test_subject,
            patch_size,
            overlap,
        )
        aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode="average")
        patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=batch_size, shuffle=False, num_workers=num_worker)

        total_batches = len(patch_loader)
        progress_bar = st.progress(0)

        for i, patches_batch in enumerate(patch_loader):
            st.text(f"Processing batch {i + 1} of {total_batches}...")

            local_batch = patches_batch['img'][tio.DATA].float()
            # Normalise by the batch maximum, but skip batches that contain
            # only zeros (e.g. pure background): 0 / 0 would produce NaNs that
            # end up in the aggregated output.
            # NOTE(review): the scaling is per batch, whereas full-volume
            # inference scales by the volume maximum - confirm this is intended.
            batch_peak = local_batch.max()
            if batch_peak != 0:
                local_batch = local_batch / batch_peak
            locations = patches_batch[tio.LOCATION]

            # Move the last axis in front of the other two spatial axes for
            # the model, and back again before aggregation.
            local_batch = torch.movedim(local_batch, -1, -3)

            output = model(local_batch)
            if isinstance(output, (tuple, list)):
                output = output[0]
            output = torch.sigmoid(output).detach().cpu()
            output = torch.movedim(output, -3, -1).type(local_batch.type())

            aggregator.add_batch(output, locations)
            progress_bar.progress((i + 1) / total_batches)
            # st.text(f"CPU usage: {psutil.cpu_percent()}% | RAM usage: {psutil.virtual_memory().percent}%")

        predicted = aggregator.get_output_tensor().squeeze().numpy()
    return predicted
# Set page configuration
st.set_page_config(
    page_title="DS6 | Segmenting vessels in 3D MRA-ToF (ideally, 7T)",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Sidebar content: what the app does and how to use it.
# The instructions describe the input this app actually processes (a 3D
# MRA-ToF brain volume) and the output it actually produces (a vessel
# segmentation), in line with the page title and the description paragraph.
with st.sidebar:
    st.title("Segmenting vessels in the brain from a 3D Magnetic Resonance Angiograph, ideally acquired at 7T | DS6")
    st.markdown("""
This application allows you to upload a 3D NIfTI file (dims: H x W x D, where the final dim is the slice dim in the axial plane), process it through a pre-trained 3D model (from DS6 and other related works), and download the output as a `.nii.gz` file containing the vessel segmentation.
**Instructions**:
- Upload your 3D NIfTI file (`.nii` or `.nii.gz`). It should be a 3D MRA-ToF volume of the brain (ideally acquired at 7T), where the final dimension is the slice dimension in the axial plane.
- Select a model from the dropdown menu.
- Click the "Process" button to generate the vessel segmentation.
""")
    st.markdown("---")
    st.markdown("© 2024 Soumick Chatterjee")
# Main page: title followed by the input and run-configuration widgets.
st.header("DS6, Deformation-Aware Semi-Supervised Learning: Application to Small Vessel Segmentation with Noisy Training Data")

# Volume to be segmented
uploaded_file = st.file_uploader(
    "Please upload a 3D NIfTI file (.nii or .nii.gz)", type=["nii", "nii.gz"]
)

# Which pretrained weights to use
model_options = ["SMILEUHURA_DS6_CamSVD_UNetMSS3D_wDeform"]
selected_model = st.selectbox("Select a pretrained model:", model_options)

# Whole-volume vs. patch-wise inference
mode_options = ["Full volume inference", "Patch-based inference [Default for DS6]"]
selected_mode = st.selectbox("Select the running mode:", mode_options)

# The patch/stride/batch settings are only shown (and only defined) when the
# patch-based mode is selected; they are laid out in three columns.
if selected_mode == "Patch-based inference [Default for DS6]":
    col1, col2, col3 = st.columns(3)
    patch_size = col1.number_input("Patch size:", min_value=1, value=64)
    stride_length = col1.number_input("Stride length:", min_value=1, value=32)
    batch_size = col2.number_input("Batch size:", min_value=1, value=14)
    stride_width = col2.number_input("Stride width:", min_value=1, value=32)
    num_worker = col3.number_input("Number of workers:", min_value=1, value=3)
    stride_depth = col3.number_input("Stride depth:", min_value=1, value=16)

# Inference is started explicitly by the user
process_button = st.button("Process")
@st.cache_resource
def load_model(model_name):
    """Load a pretrained model from the Hugging Face Hub in eval mode.

    The loaded model is cached by ``st.cache_resource``. Failures are raised
    rather than reported by returning ``None``: a returned ``None`` would be
    cached like any other value and a transient failure would never be retried.

    Args:
        model_name: Hub repository id of the model.

    Returns:
        The loaded model, switched to eval mode.

    Raises:
        Exception: whatever ``AutoModel.from_pretrained`` raises.
    """
    # NOTE(review): newer transformers releases deprecate `use_auth_token` in
    # favour of `token`; kept as-is because the installed version is not
    # visible from this file - confirm before switching.
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        use_auth_token=os.environ.get('HF_API_TOKEN')
    )
    model.eval()
    return model


# Run the pipeline only when a file is present AND the user pressed "Process":
# read the volume, load the model, run inference, binarise, offer the download.
if uploaded_file is not None and process_button:
    try:
        # Save the uploaded file to a temporary file; the suffix is kept so
        # that nibabel can recognise .nii vs .nii.gz
        file_extension = ''.join(Path(uploaded_file.name).suffixes)
        with tempfile.NamedTemporaryFile(suffix=file_extension) as tmp_file:
            tmp_file.write(uploaded_file.read())
            tmp_file.flush()

            # Load everything that is needed later while the file still exists
            nifti_img = nib.load(tmp_file.name)
            affine = nifti_img.affine
            data = nifti_img.get_fdata()

            # Convert to PyTorch tensor
            tensor = torch.from_numpy(data).float()

        # Ensure it's 3D
        if tensor.ndim != 3:
            st.error("The uploaded NIfTI file is not a 3D volume. Please upload a valid 3D NIfTI file.")
        else:
            # Display input details
            st.success("File successfully uploaded and read.")
            st.write(f"Input tensor shape: `{tensor.shape}`")
            st.write(f"Selected pretrained model: `{selected_model}`")

            # Construct the model name based on the selected model
            model_name = f"soumickmj/{selected_model}"

            # Load the pre-trained model from Hugging Face; stop the app if it
            # cannot be loaded
            if os.environ.get('HF_API_TOKEN') is None:
                st.error("Hugging Face API token is not set. Please set the 'HF_API_TOKEN' environment variable.")
                st.stop()
            with st.spinner('Loading the pre-trained model...'):
                try:
                    model = load_model(model_name)
                except Exception as e:
                    st.error(f"Failed to load model: {e}")
                    st.stop()

            # Move model and tensor to CPU (ensure compatibility with Spaces)
            device = torch.device('cpu')
            model = model.to(device)
            tensor = tensor.to(device)

            # Process the tensor through the model
            with st.spinner('Processing the tensor through the model...'):
                if selected_mode == "Full volume inference":
                    st.info("Running full volume inference...")
                    output = infer_full_vol(tensor, model)
                else:
                    st.info("Running patch-based inference [Default for DS6]...")
                    output = infer_patch_based(tensor, model, patch_size=patch_size, stride_length=stride_length, stride_width=stride_width, stride_depth=stride_depth, batch_size=batch_size, num_worker=num_worker)

            st.success("Processing complete.")
            st.write(f"Output tensor shape: `{output.shape}`")

            # Binarise the probability map with Otsu's threshold
            try:
                thresh = threshold_otsu(output)
                output = output > thresh
            except Exception as error:
                st.error(f"Otsu thresholding failed: {error}. Defaulting to a threshold of 0.5.")
                output = output > 0.5  # exception only if input image seems to have just one color 1.0.
            output = output.astype('uint16')

            # Save the output as a NIfTI file and read it back for the
            # download. The temporary directory is removed as soon as the
            # bytes are in memory, so no file is left behind per run.
            output_img = nib.Nifti1Image(output, affine=affine)
            with tempfile.TemporaryDirectory() as output_dir:
                output_path = os.path.join(output_dir, 'segmentation_output.nii.gz')
                nib.save(output_img, output_path)
                with open(output_path, "rb") as f:
                    output_data = f.read()

            # Download button for NIfTI file
            st.download_button(
                label="Download Segmentation Output",
                data=output_data,
                file_name='segmentation_output.nii.gz',
                mime='application/gzip'
            )
    except Exception as e:
        st.error(f"An error occurred: {e}")
elif uploaded_file is None:
    st.info("Awaiting file upload...")
elif not process_button:
    st.info("Click the 'Process' button to start processing.")