FrancescoLR commited on
Commit
0496106
·
verified ·
1 Parent(s): 81fd262

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -12
app.py CHANGED
@@ -33,7 +33,7 @@ def extract_middle_slices(nifti_path, output_image_path, slice_size=180):
33
  Extracts slices centered around the center of mass of non-zero voxels in a 3D NIfTI image.
34
  The slices are taken along axial, coronal, and sagittal planes and saved as a single PNG.
35
  """
36
- # Load NIfTI image and get the data
37
  img = nib.load(nifti_path)
38
  data = img.get_fdata()
39
 
@@ -44,18 +44,23 @@ def extract_middle_slices(nifti_path, output_image_path, slice_size=180):
44
  # Define half the slice size to extract regions around the center of mass
45
  half_size = slice_size // 2
46
 
47
- # Safely extract slices with boundary checks
48
- def safe_slice(data, center, axis, half_size):
49
  slices = [slice(None)] * 3
50
- slices[axis] = slice(
51
- max(center[axis] - half_size, 0),
52
- min(center[axis] + half_size, data.shape[axis])
53
- )
54
- return data[tuple(slices)]
55
 
56
- axial_slice = safe_slice(data, center, axis=2, half_size=half_size) # Axial (z-axis)
57
- coronal_slice = safe_slice(data, center, axis=1, half_size=half_size) # Coronal (y-axis)
58
- sagittal_slice = safe_slice(data, center, axis=0, half_size=half_size) # Sagittal (x-axis)
 
 
 
 
 
 
 
 
59
 
60
  # Create subplots
61
  fig, axes = plt.subplots(1, 3, figsize=(12, 4))
@@ -78,7 +83,6 @@ def extract_middle_slices(nifti_path, output_image_path, slice_size=180):
78
  plt.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
79
  plt.close()
80
 
81
-
82
  # Function to run nnUNet inference
83
  @spaces.GPU # Decorate the function to allocate GPU for its execution
84
  def run_nnunet_predict(nifti_file):
 
33
  Extracts slices centered around the center of mass of non-zero voxels in a 3D NIfTI image.
34
  The slices are taken along axial, coronal, and sagittal planes and saved as a single PNG.
35
  """
36
+ # Load NIfTI image and get the data
37
  img = nib.load(nifti_path)
38
  data = img.get_fdata()
39
 
 
44
  # Define half the slice size to extract regions around the center of mass
45
  half_size = slice_size // 2
46
 
47
+ # Safely extract 2D slices
48
+ def extract_2d_slice(data, center, axis):
49
  slices = [slice(None)] * 3
50
+ slices[axis] = center[axis]
51
+ extracted_slice = data[tuple(slices)]
 
 
 
52
 
53
+ # Crop around the center for the other two dimensions
54
+ other_axes = [i for i in range(3) if i != axis]
55
+ for i in other_axes:
56
+ start = max(center[i] - half_size, 0)
57
+ end = min(center[i] + half_size, data.shape[i])
58
+ extracted_slice = np.take(extracted_slice, range(start, end), axis=i)
59
+ return extracted_slice
60
+
61
+ axial_slice = extract_2d_slice(data, center, axis=2) # Axial (z-axis)
62
+ coronal_slice = extract_2d_slice(data, center, axis=1) # Coronal (y-axis)
63
+ sagittal_slice = extract_2d_slice(data, center, axis=0) # Sagittal (x-axis)
64
 
65
  # Create subplots
66
  fig, axes = plt.subplots(1, 3, figsize=(12, 4))
 
83
  plt.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
84
  plt.close()
85
 
 
86
  # Function to run nnUNet inference
87
  @spaces.GPU # Decorate the function to allocate GPU for its execution
88
  def run_nnunet_predict(nifti_file):