soumickmj commited on
Commit
ecc769b
·
1 Parent(s): 66e94fc
Files changed (2) hide show
  1. app.py +147 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import json
3
+ import numpy as np
4
+ import nibabel as nib
5
+ import torch
6
+ import scipy.io
7
+ from io import BytesIO
8
+ from transformers import AutoModel
9
+ import os
10
+ import tempfile
11
+ from pathlib import Path
12
+ import pandas as pd
13
+
14
+ # Set page configuration
15
+ st.set_page_config(
16
+ page_title="DS6 | Segmenting vessels in 3D MRA-ToF (ideally, 7T)",
17
+ page_icon="🧠",
18
+ layout="wide",
19
+ initial_sidebar_state="expanded",
20
+ )
21
+
22
+ # Sidebar content
23
+ with st.sidebar:
24
+ st.title("Segmenting vessels in the brain from a 3D Magnetic Resonance Angiograph, ideally acquired at 7T | DS6")
25
+ st.markdown("""
26
+ This application allows you to upload a 3D NIfTI file (dims: H x W x D), process it through a pre-trained 3D model (from DS6 and other related works), and download the output as a `.nii.gz` file containing the vessel segmentation.
27
+
28
+ **Instructions**:
29
+ - Upload your 3D NIfTI file (`.nii` or `.nii.gz`). It should be a single-slice cardiac long-axis dynamic CINE scan, where the first dimension represents time.
30
+ - Select a seed value from the dropdown menu.
31
+ - Click the "Process" button to generate the latent factors.
32
+ """)
33
+ st.markdown("---")
34
+ st.markdown("© 2024 Soumick Chatterjee")
35
+
36
+ # Main content
37
+ st.header("From single-slice cardiac long-axis dynamic CINE scan (3D: HxWxD) to 128 latent factors...")
38
+
39
+ # File uploader
40
+ uploaded_file = st.file_uploader(
41
+ "Please upload a 3D NIfTI file (.nii or .nii.gz)",
42
+ type=["nii", "nii.gz"]
43
+ )
44
+
45
+ # Seed selection
46
+ model_options = ["SMILEUHURA_DS6_CamSVD_UNetMSS3D_wDeform"]
47
+ selected_model = st.selectbox("Select a pretrained model:", model_options)
48
+
49
+ # Process button
50
+ process_button = st.button("Process")
51
+
52
+ if uploaded_file is not None and process_button:
53
+ try:
54
+ # Save the uploaded file to a temporary file
55
+ file_extension = ''.join(Path(uploaded_file.name).suffixes)
56
+ with tempfile.NamedTemporaryFile(suffix=file_extension) as tmp_file:
57
+ tmp_file.write(uploaded_file.read())
58
+ tmp_file.flush()
59
+
60
+ # Load the NIfTI file from the temporary file
61
+ nifti_img = nib.load(tmp_file.name)
62
+ data = nifti_img.get_fdata()
63
+
64
+ # Convert to PyTorch tensor
65
+ tensor = torch.from_numpy(data).float()
66
+
67
+ # Ensure it's 3D
68
+ if tensor.ndim != 3:
69
+ st.error("The uploaded NIfTI file is not a 3D volume. Please upload a valid 3D NIfTI file.")
70
+ else:
71
+ # Display input details
72
+ st.success("File successfully uploaded and read.")
73
+ st.write(f"Input tensor shape: `{tensor.shape}`")
74
+ st.write(f"Selected pretrained model: `{selected_model}`")
75
+
76
+ # Add batch and channel dimensions
77
+ tensor = tensor.unsqueeze(0).unsqueeze(0) # Shape: [1, 1, D, H, W]
78
+
79
+ # Construct the model name based on the selected seed
80
+ model_name = f"soumickmj/{selected_model}"
81
+
82
+ # Load the pre-trained model from Hugging Face
83
+ @st.cache_resource
84
+ def load_model(model_name):
85
+ hf_token = os.environ.get('HF_API_TOKEN')
86
+ if hf_token is None:
87
+ st.error("Hugging Face API token is not set. Please set the 'HF_API_TOKEN' environment variable.")
88
+ return None
89
+ try:
90
+ model = AutoModel.from_pretrained(
91
+ model_name,
92
+ trust_remote_code=True,
93
+ use_auth_token=hf_token
94
+ )
95
+ model.eval()
96
+ return model
97
+ except Exception as e:
98
+ st.error(f"Failed to load model: {e}")
99
+ return None
100
+
101
+ with st.spinner('Loading the pre-trained model...'):
102
+ model = load_model(model_name)
103
+ if model is None:
104
+ st.stop() # Stop the app if the model couldn't be loaded
105
+
106
+ # Move model and tensor to CPU (ensure compatibility with Spaces)
107
+ device = torch.device('cpu')
108
+ model = model.to(device)
109
+ tensor = tensor.to(device)
110
+
111
+ # Process the tensor through the model
112
+ with st.spinner('Processing the tensor through the model...'):
113
+ with torch.no_grad():
114
+ output = model.encode(tensor, use_ema=model.config.test_ema)
115
+ if isinstance(output, tuple):
116
+ output = output[0]
117
+ output = output.squeeze(0)
118
+
119
+ st.success("Processing complete.")
120
+ st.write(f"Output tensor shape: `{output.shape}`")
121
+
122
+ # Convert output to NumPy array
123
+ output_np = output.detach().cpu().numpy()
124
+
125
+ # Save the output as a NIfTI file
126
+ output_img = nib.Nifti1Image(output_np, affine=nifti_img.affine)
127
+ output_path = tempfile.NamedTemporaryFile(suffix='.nii.gz', delete=False).name
128
+ nib.save(output_img, output_path)
129
+
130
+ # Read the saved file for download
131
+ with open(output_path, "rb") as f:
132
+ output_data = f.read()
133
+
134
+ # Download button for NIfTI file
135
+ st.download_button(
136
+ label="Download Segmentation Output",
137
+ data=output_data,
138
+ file_name='segmentation_output.nii.gz',
139
+ mime='application/gzip'
140
+ )
141
+
142
+ except Exception as e:
143
+ st.error(f"An error occurred: {e}")
144
+ elif uploaded_file is None:
145
+ st.info("Awaiting file upload...")
146
+ elif not process_button:
147
+ st.info("Click the 'Process' button to start processing.")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ nibabel
2
+ torch
3
+ pytorch_lightning
4
+ scipy
5
+ transformers
6
+ torchvision