FrancescoLR commited on
Commit
121d535
·
verified ·
1 Parent(s): 842b2ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -7
app.py CHANGED
@@ -8,7 +8,7 @@ import nibabel as nib
8
  import matplotlib.pyplot as plt
9
  import spaces # Import spaces for GPU decoration
10
  import numpy as np
11
- from scipy.ndimage import center_of_mass
12
 
13
  # Define paths
14
  MODEL_DIR = "./model" # Local directory to store the downloaded model
@@ -27,18 +27,50 @@ def download_model():
27
  zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)
28
  subprocess.run(["unzip", "-o", zip_path, "-d", MODEL_DIR])
29
  print("Dataset004_WML downloaded and extracted.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
 
 
 
 
 
 
31
  def extract_middle_slices(nifti_path, output_image_path, slice_size=180):
32
  """
33
  Extracts slices centered around the center of mass of non-zero voxels in a 3D NIfTI image.
34
  The slices are taken along axial, coronal, and sagittal planes and saved as a single PNG.
35
  """
36
- # Load NIfTI image and get the data
37
  img = nib.load(nifti_path)
38
  data = img.get_fdata()
 
 
 
 
39
 
40
  # Compute the center of mass of non-zero voxels
41
- com = center_of_mass(data > 0)
42
  center = np.round(com).astype(int)
43
 
44
  # Define half the slice size
@@ -67,9 +99,9 @@ def extract_middle_slices(nifti_path, output_image_path, slice_size=180):
67
  return padded_slice
68
 
69
  # Extract slices in axial, coronal, and sagittal planes
70
- axial_slice = extract_2d_slice(data, center, axis=2) # Axial (z-axis)
71
- coronal_slice = extract_2d_slice(data, center, axis=1) # Coronal (y-axis)
72
- sagittal_slice = extract_2d_slice(data, center, axis=0) # Sagittal (x-axis)
73
 
74
  # Apply rotations to each slice
75
  axial_slice = np.rot90(axial_slice, k=-1) # 90 degrees clockwise
@@ -95,7 +127,7 @@ def extract_middle_slices(nifti_path, output_image_path, slice_size=180):
95
  plt.tight_layout()
96
  plt.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
97
  plt.close()
98
-
99
  # Function to run nnUNet inference
100
  @spaces.GPU # Decorate the function to allocate GPU for its execution
101
  def run_nnunet_predict(nifti_file):
 
8
  import matplotlib.pyplot as plt
9
  import spaces # Import spaces for GPU decoration
10
  import numpy as np
11
+ from scipy.ndimage import center_of_mass, zoom
12
 
13
  # Define paths
14
  MODEL_DIR = "./model" # Local directory to store the downloaded model
 
27
  zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)
28
  subprocess.run(["unzip", "-o", zip_path, "-d", MODEL_DIR])
29
  print("Dataset004_WML downloaded and extracted.")
30
+
31
+ def resample_to_isotropic(data, affine, target_spacing=1.0):
32
+ """
33
+ Resamples a 3D NIfTI image to isotropic voxel size.
34
+
35
+ Parameters:
36
+ data (numpy.ndarray): The input 3D image data.
37
+ affine (numpy.ndarray): The affine transformation matrix.
38
+ target_spacing (float): Desired isotropic voxel spacing (in mm).
39
+
40
+ Returns:
41
+ resampled_data (numpy.ndarray): Resampled image data.
42
+ resampled_affine (numpy.ndarray): Updated affine matrix.
43
+ """
44
+ # Extract current voxel dimensions from the affine matrix
45
+ current_spacing = np.sqrt((affine[:3, :3] ** 2).sum(axis=0))
46
+
47
+ # Compute the scaling factors for resampling
48
+ scaling_factors = current_spacing / target_spacing
49
+
50
+ # Resample the data using zoom
51
+ resampled_data = zoom(data, zoom=scaling_factors, order=1) # Linear interpolation
52
 
53
+ # Update the affine matrix to reflect the new voxel dimensions
54
+ resampled_affine = affine.copy()
55
+ resampled_affine[:3, :3] /= scaling_factors[:, np.newaxis]
56
+
57
+ return resampled_data, resampled_affine
58
+
59
  def extract_middle_slices(nifti_path, output_image_path, slice_size=180):
60
  """
61
  Extracts slices centered around the center of mass of non-zero voxels in a 3D NIfTI image.
62
  The slices are taken along axial, coronal, and sagittal planes and saved as a single PNG.
63
  """
64
+ # Load NIfTI image
65
  img = nib.load(nifti_path)
66
  data = img.get_fdata()
67
+ affine = img.affine
68
+
69
+ # Resample the image to 1 mm isotropic
70
+ resampled_data, _ = resample_to_isotropic(data, affine, target_spacing=1.0)
71
 
72
  # Compute the center of mass of non-zero voxels
73
+ com = center_of_mass(resampled_data > 0)
74
  center = np.round(com).astype(int)
75
 
76
  # Define half the slice size
 
99
  return padded_slice
100
 
101
  # Extract slices in axial, coronal, and sagittal planes
102
+ axial_slice = extract_2d_slice(resampled_data, center, axis=2) # Axial (z-axis)
103
+ coronal_slice = extract_2d_slice(resampled_data, center, axis=1) # Coronal (y-axis)
104
+ sagittal_slice = extract_2d_slice(resampled_data, center, axis=0) # Sagittal (x-axis)
105
 
106
  # Apply rotations to each slice
107
  axial_slice = np.rot90(axial_slice, k=-1) # 90 degrees clockwise
 
127
  plt.tight_layout()
128
  plt.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
129
  plt.close()
130
+
131
  # Function to run nnUNet inference
132
  @spaces.GPU # Decorate the function to allocate GPU for its execution
133
  def run_nnunet_predict(nifti_file):