Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -10,6 +10,7 @@ import tempfile
|
|
10 |
from pathlib import Path
|
11 |
from skimage.filters import threshold_otsu
|
12 |
import torchio as tio
|
|
|
13 |
|
14 |
def infer_full_vol(tensor, model):
|
15 |
tensor = tensor.unsqueeze(0).unsqueeze(0) # Shape: [1, 1, D, H, W] - adding batch and channel dims
|
@@ -75,6 +76,8 @@ def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_wid
|
|
75 |
aggregator.add_batch(output, locations)
|
76 |
|
77 |
progress_bar.progress((i + 1) / total_batches)
|
|
|
|
|
78 |
|
79 |
predicted = aggregator.get_output_tensor().squeeze().numpy()
|
80 |
|
@@ -119,6 +122,15 @@ selected_model = st.selectbox("Select a pretrained model:", model_options)
|
|
119 |
mode_options = ["Full volume inference", "Patch-based inference [Default for DS6]"]
|
120 |
selected_mode = st.selectbox("Select the running mode:", mode_options)
|
121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
# Process button
|
123 |
process_button = st.button("Process")
|
124 |
|
@@ -185,7 +197,7 @@ if uploaded_file is not None and process_button:
|
|
185 |
output = infer_full_vol(tensor, model)
|
186 |
else:
|
187 |
st.info("Running patch-based inference [Default for DS6]...")
|
188 |
-
output = infer_patch_based(tensor, model)
|
189 |
|
190 |
st.success("Processing complete.")
|
191 |
st.write(f"Output tensor shape: `{output.shape}`")
|
@@ -194,7 +206,7 @@ if uploaded_file is not None and process_button:
|
|
194 |
thresh = threshold_otsu(output)
|
195 |
output = output > thresh
|
196 |
except Exception as error:
|
197 |
-
|
198 |
output = output > 0.5 # exception only if input image seems to have just one color 1.0.
|
199 |
output = output.astype('uint16')
|
200 |
|
|
|
10 |
from pathlib import Path
|
11 |
from skimage.filters import threshold_otsu
|
12 |
import torchio as tio
|
13 |
+
import psutil
|
14 |
|
15 |
def infer_full_vol(tensor, model):
|
16 |
tensor = tensor.unsqueeze(0).unsqueeze(0) # Shape: [1, 1, D, H, W] - adding batch and channel dims
|
|
|
76 |
aggregator.add_batch(output, locations)
|
77 |
|
78 |
progress_bar.progress((i + 1) / total_batches)
|
79 |
+
st.text(f"Processing batch {i + 1} of {total_batches}... ({((i + 1) / total_batches) * 100:.2f}% complete)")
|
80 |
+
st.text(f"CPU usage: {psutil.cpu_percent()}% | RAM usage: {psutil.virtual_memory().percent}%")
|
81 |
|
82 |
predicted = aggregator.get_output_tensor().squeeze().numpy()
|
83 |
|
|
|
122 |
mode_options = ["Full volume inference", "Patch-based inference [Default for DS6]"]
|
123 |
selected_mode = st.selectbox("Select the running mode:", mode_options)
|
124 |
|
125 |
+
# Parameters for patch-based inference
|
126 |
+
if selected_mode == "Patch-based inference [Default for DS6]":
|
127 |
+
patch_size = st.number_input("Patch size:", min_value=1, value=64)
|
128 |
+
stride_length = st.number_input("Stride length:", min_value=1, value=32)
|
129 |
+
stride_width = st.number_input("Stride width:", min_value=1, value=32)
|
130 |
+
stride_depth = st.number_input("Stride depth:", min_value=1, value=16)
|
131 |
+
batch_size = st.number_input("Batch size:", min_value=1, value=10)
|
132 |
+
num_worker = st.number_input("Number of workers:", min_value=1, value=2)
|
133 |
+
|
134 |
# Process button
|
135 |
process_button = st.button("Process")
|
136 |
|
|
|
197 |
output = infer_full_vol(tensor, model)
|
198 |
else:
|
199 |
st.info("Running patch-based inference [Default for DS6]...")
|
200 |
+
output = infer_patch_based(tensor, model, patch_size=patch_size, stride_length=stride_length, stride_width=stride_width, stride_depth=stride_depth, batch_size=batch_size, num_worker=num_worker)
|
201 |
|
202 |
st.success("Processing complete.")
|
203 |
st.write(f"Output tensor shape: `{output.shape}`")
|
|
|
206 |
thresh = threshold_otsu(output)
|
207 |
output = output > thresh
|
208 |
except Exception as error:
|
209 |
+
st.error(f"Otsu thresholding failed: {error}. Defaulting to a threshold of 0.5.")
|
210 |
output = output > 0.5 # exception only if input image seems to have just one color 1.0.
|
211 |
output = output.astype('uint16')
|
212 |
|