soumickmj commited on
Commit
23492e1
·
1 Parent(s): 59075ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -2
app.py CHANGED
@@ -10,6 +10,7 @@ import tempfile
10
  from pathlib import Path
11
  from skimage.filters import threshold_otsu
12
  import torchio as tio
 
13
 
14
  def infer_full_vol(tensor, model):
15
  tensor = tensor.unsqueeze(0).unsqueeze(0) # Shape: [1, 1, D, H, W] - adding batch and channel dims
@@ -75,6 +76,8 @@ def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_wid
75
  aggregator.add_batch(output, locations)
76
 
77
  progress_bar.progress((i + 1) / total_batches)
 
 
78
 
79
  predicted = aggregator.get_output_tensor().squeeze().numpy()
80
 
@@ -119,6 +122,15 @@ selected_model = st.selectbox("Select a pretrained model:", model_options)
119
  mode_options = ["Full volume inference", "Patch-based inference [Default for DS6]"]
120
  selected_mode = st.selectbox("Select the running mode:", mode_options)
121
 
 
 
 
 
 
 
 
 
 
122
  # Process button
123
  process_button = st.button("Process")
124
 
@@ -185,7 +197,7 @@ if uploaded_file is not None and process_button:
185
  output = infer_full_vol(tensor, model)
186
  else:
187
  st.info("Running patch-based inference [Default for DS6]...")
188
- output = infer_patch_based(tensor, model)
189
 
190
  st.success("Processing complete.")
191
  st.write(f"Output tensor shape: `{output.shape}`")
@@ -194,7 +206,7 @@ if uploaded_file is not None and process_button:
194
  thresh = threshold_otsu(output)
195
  output = output > thresh
196
  except Exception as error:
197
- print(error)
198
  output = output > 0.5 # exception only if input image seems to have just one color 1.0.
199
  output = output.astype('uint16')
200
 
 
10
  from pathlib import Path
11
  from skimage.filters import threshold_otsu
12
  import torchio as tio
13
+ import psutil
14
 
15
  def infer_full_vol(tensor, model):
16
  tensor = tensor.unsqueeze(0).unsqueeze(0) # Shape: [1, 1, D, H, W] - adding batch and channel dims
 
76
  aggregator.add_batch(output, locations)
77
 
78
  progress_bar.progress((i + 1) / total_batches)
79
+ st.text(f"Processing batch {i + 1} of {total_batches}... ({((i + 1) / total_batches) * 100:.2f}% complete)")
80
+ st.text(f"CPU usage: {psutil.cpu_percent()}% | RAM usage: {psutil.virtual_memory().percent}%")
81
 
82
  predicted = aggregator.get_output_tensor().squeeze().numpy()
83
 
 
122
  mode_options = ["Full volume inference", "Patch-based inference [Default for DS6]"]
123
  selected_mode = st.selectbox("Select the running mode:", mode_options)
124
 
125
+ # Parameters for patch-based inference
126
+ if selected_mode == "Patch-based inference [Default for DS6]":
127
+ patch_size = st.number_input("Patch size:", min_value=1, value=64)
128
+ stride_length = st.number_input("Stride length:", min_value=1, value=32)
129
+ stride_width = st.number_input("Stride width:", min_value=1, value=32)
130
+ stride_depth = st.number_input("Stride depth:", min_value=1, value=16)
131
+ batch_size = st.number_input("Batch size:", min_value=1, value=10)
132
+ num_worker = st.number_input("Number of workers:", min_value=1, value=2)
133
+
134
  # Process button
135
  process_button = st.button("Process")
136
 
 
197
  output = infer_full_vol(tensor, model)
198
  else:
199
  st.info("Running patch-based inference [Default for DS6]...")
200
+ output = infer_patch_based(tensor, model, patch_size=patch_size, stride_length=stride_length, stride_width=stride_width, stride_depth=stride_depth, batch_size=batch_size, num_worker=num_worker)
201
 
202
  st.success("Processing complete.")
203
  st.write(f"Output tensor shape: `{output.shape}`")
 
206
  thresh = threshold_otsu(output)
207
  output = output > thresh
208
  except Exception as error:
209
+ st.error(f"Otsu thresholding failed: {error}. Defaulting to a threshold of 0.5.")
210
  output = output > 0.5 # exception only if input image seems to have just one color 1.0.
211
  output = output.astype('uint16')
212