Spaces:
Runtime error
Runtime error
import streamlit as st | |
import math | |
import numpy as np | |
import nibabel as nib | |
import torch | |
import torch.nn.functional as F | |
from transformers import AutoModel | |
import os | |
import tempfile | |
from pathlib import Path | |
from skimage.filters import threshold_otsu | |
import torchio as tio | |
# import psutil | |
def infer_full_vol(tensor, model, size_multiple=16):
    """Segment a whole volume in a single forward pass.

    The volume is scaled to a maximum of 1, zero-padded (split as evenly as
    possible before/after) so every spatial size is a multiple of
    ``size_multiple``, passed through ``model``, and the padding is cropped
    off the sigmoid output again.

    Args:
        tensor: 3D tensor ordered (H, W, D), D being the axial slice axis as
            described in the app's upload instructions.
        model: callable taking a (1, 1, D, H, W) tensor. It may return a
            tensor or a tuple/list whose first element is used.
        size_multiple: each padded spatial size is rounded up to a multiple
            of this value (default 16).

    Returns:
        numpy array of sigmoid probabilities in the input's (H, W, D) order.
        Leading singleton batch/channel axes are dropped; spatial axes are
        always kept, even when they have length 1.
    """
    # [H, W, D] -> [1, 1, H, W, D]: add batch and channel dims.
    tensor = tensor.unsqueeze(0).unsqueeze(0)
    # -> [1, 1, D, H, W]: the model is fed the slice axis before H and W.
    tensor = torch.movedim(tensor, -1, -3)

    max_val = tensor.max()
    if max_val != 0:
        # Scale to a maximum of 1. Skipped for an all-zero volume, where the
        # division would turn every voxel into NaN.
        tensor = tensor / max_val

    sizes = tensor.shape[-3:]
    new_sizes = [math.ceil(s / size_multiple) * size_multiple for s in sizes]
    total_pads = [new - old for old, new in zip(sizes, new_sizes)]
    pad_before = [pad // 2 for pad in total_pads]
    pad_after = [pad - before for pad, before in zip(total_pads, pad_before)]
    # F.pad takes (before, after) pairs starting from the LAST dimension.
    padding = []
    for before, after in zip(reversed(pad_before), reversed(pad_after)):
        padding.extend([before, after])
    tensor = F.pad(tensor, padding)

    with torch.no_grad():
        output = model(tensor)
    if isinstance(output, (tuple, list)):
        # Models returning several outputs: keep the first one.
        output = output[0]
    output = torch.sigmoid(output)

    # Crop the padding off the last three (spatial) axes again.
    crop = [slice(None)] * output.dim()
    for offset, (start, size) in enumerate(zip(pad_before, sizes)):
        crop[offset - 3] = slice(start, start + size)
    output = output[tuple(crop)]

    # Back to [..., H, W, D].
    output = torch.movedim(output, -3, -1).type(tensor.type())
    # Drop leading singleton (batch/channel) axes but never the spatial ones,
    # unlike a bare squeeze(), which would also remove a length-1 spatial axis.
    while output.dim() > 3 and output.shape[0] == 1:
        output = output[0]
    return output.detach().cpu().numpy()
def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_width=32, stride_depth=16, batch_size=10, num_worker=2):
    """Segment a volume patch by patch with TorchIO's grid sampler/aggregator.

    Args:
        tensor: 3D tensor ordered (H, W, D), D being the axial slice axis as
            described in the app's upload instructions.
        model: callable applied to each batch of patches after the slice axis
            has been moved before the other two spatial axes. It may return a
            tensor or a tuple/list whose first element is used.
        patch_size: edge length of the patches passed to the grid sampler.
        stride_length, stride_width, stride_depth: step between neighbouring
            patches along the three axes. The per-axis overlap handed to the
            sampler is ``patch_size - stride``.
            NOTE(review): TorchIO's GridSampler is documented to require
            even, non-negative overlaps. Confirm the UI values satisfy this.
        batch_size: patches per forward pass.
        num_worker: number of DataLoader worker processes.

    Returns:
        numpy array of sigmoid probabilities with overlapping predictions
        averaged. Leading singleton (channel) axes are dropped; spatial axes
        are always kept.
    """
    # Add the channel dim the TorchIO image expects: [1, H, W, D].
    test_subject = tio.Subject(img=tio.ScalarImage(tensor=tensor.unsqueeze(0)))
    overlap = np.subtract(patch_size, (stride_length, stride_width, stride_depth))
    grid_sampler = tio.inference.GridSampler(test_subject, patch_size, overlap)
    aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode="average")
    patch_loader = torch.utils.data.DataLoader(
        grid_sampler, batch_size=batch_size, shuffle=False, num_workers=num_worker
    )
    total_batches = len(patch_loader)
    progress_bar = st.progress(0)
    # One placeholder overwritten every batch, instead of appending a new
    # text element to the page per batch.
    status_text = st.empty()

    with torch.no_grad():
        for batch_idx, patches_batch in enumerate(patch_loader, start=1):
            status_text.text(f"Processing batch {batch_idx} of {total_batches}...")
            local_batch = patches_batch['img'][tio.DATA].float()
            # NOTE(review): scaling is per batch, so intensities depend on which
            # patches share a batch. Confirm this matches the training setup.
            batch_max = local_batch.max()
            if batch_max != 0:
                # Skipped for all-background batches. Dividing by zero there
                # would create NaNs that the averaging aggregator spreads into
                # every overlapping neighbour.
                local_batch = local_batch / batch_max
            locations = patches_batch[tio.LOCATION]
            # Move the last (slice) axis in front of the other two spatial axes.
            local_batch = torch.movedim(local_batch, -1, -3)
            output = model(local_batch)
            if isinstance(output, (tuple, list)):
                output = output[0]
            output = torch.sigmoid(output).detach().cpu()
            # Undo the axis move so the patches line up with `locations`.
            output = torch.movedim(output, -3, -1).type(local_batch.type())
            aggregator.add_batch(output, locations)
            progress_bar.progress(batch_idx / total_batches)

    predicted = aggregator.get_output_tensor()
    # Drop leading singleton axes, but never a length-1 spatial axis.
    while predicted.dim() > 3 and predicted.shape[0] == 1:
        predicted = predicted[0]
    return predicted.numpy()
# Set page configuration: browser-tab title and icon, wide layout, and the
# sidebar shown expanded when the page loads.
_PAGE_CONFIG = dict(
    page_title="DS6 | Segmenting vessels in 3D MRA-ToF (ideally, 7T)",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
# Sidebar content: title, usage instructions and copyright notice.
# The instructions must describe this app's input (a 3D H x W x D MRA volume)
# and output (a vessel segmentation). They previously described an unrelated
# cardiac CINE / latent-factor app.
with st.sidebar:
    st.title("Segmenting vessels in the brain from a 3D Magnetic Resonance Angiograph, ideally acquired at 7T | DS6")
    st.markdown("""
This application allows you to upload a 3D NIfTI file (dims: H x W x D, where the final dim is the slice dim in the axial plane), process it through a pre-trained 3D model (from DS6 and other related works), and download the output as a `.nii.gz` file containing the vessel segmentation.

**Instructions**:
- Upload your 3D NIfTI file (`.nii` or `.nii.gz`). It should be a 3D MRA-ToF brain volume (ideally acquired at 7T) with dimensions H x W x D, where the last dimension is the slice dimension in the axial plane.
- Select a model from the dropdown menu.
- Select the running mode (full volume or patch-based inference). For patch-based inference, adjust the patch size, strides, batch size and number of workers if needed.
- Click the "Process" button to generate the vessel segmentation, then download it as a `.nii.gz` file.
""")
    st.markdown("---")
    st.markdown("© 2024 Soumick Chatterjee")
# Main content: page header followed by the input widgets that drive the
# processing step below.
st.header("DS6, Deformation-Aware Semi-Supervised Learning: Application to Small Vessel Segmentation with Noisy Training Data")

# Upload widget; plain and gzipped NIfTI are accepted.
uploaded_file = st.file_uploader(
    "Please upload a 3D NIfTI file (.nii or .nii.gz)",
    type=["nii", "nii.gz"],
)

# Pretrained checkpoint (used as the repo name under "soumickmj/").
model_options = ["SMILEUHURA_DS6_CamSVD_UNetMSS3D_wDeform"]
selected_model = st.selectbox("Select a pretrained model:", model_options)

# Inference strategy; the second entry is the patch-based mode.
mode_options = ["Full volume inference", "Patch-based inference [Default for DS6]"]
selected_mode = st.selectbox("Select the running mode:", mode_options)

# Patch-sampling controls exist only in patch-based mode, laid out as three
# columns of two inputs each.
if selected_mode == mode_options[1]:
    left, middle, right = st.columns(3)
    patch_size = left.number_input("Patch size:", min_value=1, value=64)
    stride_length = left.number_input("Stride length:", min_value=1, value=32)
    batch_size = middle.number_input("Batch size:", min_value=1, value=14)
    stride_width = middle.number_input("Stride width:", min_value=1, value=32)
    num_worker = right.number_input("Number of workers:", min_value=1, value=3)
    stride_depth = right.number_input("Stride depth:", min_value=1, value=16)

# Processing starts only when this button is clicked.
process_button = st.button("Process")
@st.cache_resource(show_spinner=False)
def load_model(model_name):
    """Load ``model_name`` from the Hugging Face Hub, in eval mode, on the CPU.

    Wrapped in ``st.cache_resource`` so the model is downloaded and built once
    and then reused across reruns; every click of "Process" re-executes this
    script. Failures raise instead of returning ``None``: exceptions are not
    cached, so a later rerun can retry.

    Raises:
        RuntimeError: if ``HF_API_TOKEN`` is not set or loading fails.
    """
    hf_token = os.environ.get('HF_API_TOKEN')
    if hf_token is None:
        raise RuntimeError("Hugging Face API token is not set. Please set the 'HF_API_TOKEN' environment variable.")
    try:
        # trust_remote_code=True runs the model code shipped in the Hub repo.
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_auth_token=hf_token,
        )
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {e}") from e
    model.eval()
    # Inference runs on the CPU (the input tensor is built with from_numpy).
    return model.to(torch.device('cpu'))


def _read_uploaded_nifti(upload):
    """Write the upload to a temp file, load it with nibabel, return (image, data)."""
    file_extension = ''.join(Path(upload.name).suffixes)
    with tempfile.NamedTemporaryFile(suffix=file_extension) as tmp_file:
        # getvalue() returns the whole upload regardless of stream position.
        tmp_file.write(upload.getvalue())
        tmp_file.flush()
        nifti_img = nib.load(tmp_file.name)
        # Read the voxel data while the temporary file still exists.
        data = nifti_img.get_fdata()
    return nifti_img, data


def _drop_trailing_singleton_axes(data):
    """Strip length-1 axes beyond the third, e.g. H x W x D x 1 -> H x W x D."""
    while data.ndim > 3 and data.shape[-1] == 1:
        data = data[..., 0]
    return data


def _binarize(prob):
    """Threshold a probability map with Otsu's method (0.5 fallback); return uint16 0/1."""
    try:
        thresh = threshold_otsu(prob)
    except Exception as error:
        st.error(f"Otsu thresholding failed: {error}. Defaulting to a threshold of 0.5.")
        # Expected only if the map has a single value (e.g. every voxel is 1.0).
        thresh = 0.5
    return (prob > thresh).astype('uint16')


def _mask_to_nifti_bytes(mask, affine):
    """Save ``mask`` as a .nii.gz in a temporary directory and return the file bytes.

    The directory (and the file in it) is removed on exit, so no temp file or
    open handle is left behind per run.
    """
    output_img = nib.Nifti1Image(mask, affine=affine)
    with tempfile.TemporaryDirectory() as tmp_dir:
        output_path = os.path.join(tmp_dir, 'segmentation_output.nii.gz')
        nib.save(output_img, output_path)
        with open(output_path, "rb") as f:
            return f.read()


if uploaded_file is not None and process_button:
    try:
        nifti_img, data = _read_uploaded_nifti(uploaded_file)
        tensor = torch.from_numpy(_drop_trailing_singleton_axes(data)).float()

        # Ensure it's 3D
        if tensor.ndim != 3:
            st.error("The uploaded NIfTI file is not a 3D volume. Please upload a valid 3D NIfTI file.")
        else:
            st.success("File successfully uploaded and read.")
            st.write(f"Input tensor shape: `{tensor.shape}`")
            st.write(f"Selected pretrained model: `{selected_model}`")

            model_name = f"soumickmj/{selected_model}"
            model = None
            with st.spinner('Loading the pre-trained model...'):
                try:
                    model = load_model(model_name)
                except RuntimeError as load_error:
                    st.error(str(load_error))
            if model is None:
                st.stop()  # Stop the app if the model couldn't be loaded

            with st.spinner('Processing the tensor through the model...'):
                if selected_mode == "Full volume inference":
                    st.info("Running full volume inference...")
                    output = infer_full_vol(tensor, model)
                else:
                    st.info("Running patch-based inference [Default for DS6]...")
                    output = infer_patch_based(
                        tensor,
                        model,
                        patch_size=patch_size,
                        stride_length=stride_length,
                        stride_width=stride_width,
                        stride_depth=stride_depth,
                        batch_size=batch_size,
                        num_worker=num_worker,
                    )
            st.success("Processing complete.")
            st.write(f"Output tensor shape: `{output.shape}`")

            mask = _binarize(output)
            output_data = _mask_to_nifti_bytes(mask, nifti_img.affine)

            st.download_button(
                label="Download Segmentation Output",
                data=output_data,
                file_name='segmentation_output.nii.gz',
                mime='application/gzip',
            )
    except Exception as e:
        st.error(f"An error occurred: {e}")
elif uploaded_file is None:
    st.info("Awaiting file upload...")
else:
    # Reached only with a file uploaded and the button not yet clicked.
    st.info("Click the 'Process' button to start processing.")