Spaces:
Running
Running
import sys | |
import os | |
import glob | |
import SimpleITK as sitk | |
from tqdm import tqdm | |
import random | |
from HD_BET.hd_bet import hd_bet | |
import argparse | |
import torch | |
def brain_extraction(input_dir, output_dir, device, *, mode='fast', tta=0):
    """
    Skull-strip the images in a directory with the HD-BET package
    (U-Net based deep-learning brain extraction).

    Args:
        input_dir {path} -- directory containing the registered images
        output_dir {path} -- directory for the brain-extracted images;
            created if it does not exist
        device {str} -- device string forwarded to HD-BET (``main`` passes
            "0" when CUDA is available, otherwise "cpu")
        mode {str} -- HD-BET ``mode`` argument (default 'fast', the value
            this function always used before it became a parameter)
        tta {int} -- HD-BET ``tta`` argument (default 0, likewise the
            previously hard-coded value)

    Returns:
        None -- HD-BET writes its results into ``output_dir``; the directory
        listing is printed afterwards for inspection.
    """
    print("Running brain extraction...")
    print(f"Input directory: {input_dir}")
    print(f"Output directory: {output_dir}")
    # Guarantee the directory exists so the listing below cannot fail.
    os.makedirs(output_dir, exist_ok=True)
    # HD-BET is called once with the directory paths, not per file.
    hd_bet(input_dir, output_dir, device=device, mode=mode, tta=tta)
    print('Brain extraction complete!')
    print("\nContents of output directory after brain extraction:")
    print(os.listdir(output_dir))
def _image_id(path):
    """
    Return the file name of ``path`` with a trailing ``.nii.gz`` removed.

    Only that extension is stripped, so ``sub-01.T1w.nii.gz`` yields
    ``sub-01.T1w``. Cutting at the first dot would map ``a.T1.nii.gz`` and
    ``a.T2.nii.gz`` to the same ID, and hence to the same output file.
    ``os.path.basename`` also handles native Windows path separators.
    """
    name = os.path.basename(path)
    suffix = '.nii.gz'
    if name.endswith(suffix):
        return name[:-len(suffix)]
    return name


def _sitk_interpolator(interp_type):
    """
    Map an interpolation name to the matching SimpleITK interpolator constant.

    Non-string values are returned unchanged, so callers that already pass a
    SimpleITK constant (e.g. ``sitk.sitkLinear``) keep working.

    Raises:
        ValueError -- if ``interp_type`` is a string other than 'linear',
            'bspline' or 'nearest_neighbor'
    """
    if not isinstance(interp_type, str):
        return interp_type
    if interp_type == 'linear':
        return sitk.sitkLinear
    if interp_type == 'bspline':
        return sitk.sitkBSpline
    if interp_type == 'nearest_neighbor':
        return sitk.sitkNearestNeighbor
    raise ValueError(
        f"Unknown interp_type {interp_type!r}; expected 'linear', "
        "'bspline' or 'nearest_neighbor'")


def _resample_to_spacing(img, new_spacing, interpolator):
    """
    Resample ``img`` to ``new_spacing``, keeping origin and direction.

    The output size is chosen so the physical extent is preserved up to
    rounding to whole voxels. The output is float32, and voxels outside the
    source image are set to 0.
    """
    new_size = [
        int(round((size * spacing) / float(target)))
        for size, spacing, target in zip(img.GetSize(), img.GetSpacing(), new_spacing)
    ]
    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(new_spacing)
    resample.SetSize(new_size)
    resample.SetOutputOrigin(img.GetOrigin())
    resample.SetOutputDirection(img.GetDirection())
    resample.SetInterpolator(interpolator)
    # The background must be an intensity. The previous code passed
    # GetPixelIDValue(), i.e. the pixel-*type* enum number, as the fill value.
    resample.SetDefaultPixelValue(0.0)
    resample.SetOutputPixelType(sitk.sitkFloat32)
    return resample.Execute(img)


def _rigid_register(fixed_img, moving_img):
    """
    Register ``moving_img`` to ``fixed_img`` with a 3D Euler (rigid) transform.

    The registration uses Mattes mutual information (50 bins) on a random 1%
    sample of voxels, gradient descent (learning rate 1.0, 100 iterations),
    and a 3-level pyramid (shrink 4/2/1, smoothing sigmas 2/1/0 in physical
    units). No sampling seed is passed, so results may vary slightly between
    runs. Returns ``moving_img`` linearly resampled onto the fixed-image grid,
    with pixel type unchanged and 0 outside the field of view.
    """
    initial_transform = sitk.CenteredTransformInitializer(
        fixed_img,
        moving_img,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY)

    method = sitk.ImageRegistrationMethod()
    method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    method.SetMetricSamplingStrategy(method.RANDOM)
    method.SetMetricSamplingPercentage(0.01)
    method.SetInterpolator(sitk.sitkLinear)
    method.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=100,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10)
    method.SetOptimizerScalesFromPhysicalShift()
    method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
    method.SetInitialTransform(initial_transform)

    final_transform = method.Execute(fixed_img, moving_img)
    return sitk.Resample(
        moving_img,
        fixed_img,
        final_transform,
        sitk.sitkLinear,
        0.0,
        moving_img.GetPixelID())


def registration(input_dir, output_dir, temp_img, interp_type='linear'):
    """
    Rigidly register every ``*.nii.gz`` image in a directory to a template.

    The template is read as float32 and resampled once to 1 mm isotropic
    spacing. Files whose name contains ``_mask`` are skipped. Each other
    image goes through these steps:
      1. N4 bias-field correction.
      2. Rigid registration to the resampled template.
      3. Resampling onto the template grid.
      4. Writing to ``output_dir`` as ``<ID>_0000.nii.gz``.
    Images that cannot be read or registered are reported and skipped.

    Args:
        input_dir {path} -- directory containing the input ``*.nii.gz`` images
        output_dir {path} -- directory for the registered images (``main``
            creates it before calling this function)
        temp_img {str} -- path to the registration template image
        interp_type {str} -- interpolator used when resampling the template:
            'linear' (default), 'bspline' or 'nearest_neighbor'; a SimpleITK
            interpolator constant is also accepted

    Returns:
        bool -- True if at least one image was registered and saved

    Raises:
        ValueError -- if ``interp_type`` is an unknown name
    """
    # Validate before any expensive I/O.
    interpolator = _sitk_interpolator(interp_type)

    # The resampled template is the same for every input image, so compute it
    # once instead of re-resampling it on each loop iteration.
    fixed_img = sitk.ReadImage(temp_img, sitk.sitkFloat32)
    fixed_img = _resample_to_spacing(fixed_img, (1, 1, 1), interpolator)

    count = 0
    print("Registering images...")
    list_of_files = sorted(glob.glob(os.path.join(input_dir, '*.nii.gz')))
    for img_path in tqdm(list_of_files):
        ID = _image_id(img_path)
        if "_mask" in ID:
            continue

        # Each file is read exactly once. Unreadable files are reported and
        # skipped rather than aborting the whole batch.
        try:
            moving_img = sitk.ReadImage(img_path, sitk.sitkFloat32)
        except Exception as e:
            print(f"Error loading {ID}: {e}")
            print(f'Skipping problematic file: {ID}')
            continue

        print(f"Processing image {count + 1}: {ID}")
        try:
            moving_img = sitk.N4BiasFieldCorrection(moving_img)
            moving_img_resampled = _rigid_register(fixed_img, moving_img)
            # Save with _0000 suffix as required by HD-BET.
            output_filename = os.path.join(output_dir, f"{ID}_0000.nii.gz")
            sitk.WriteImage(moving_img_resampled, output_filename)
        except Exception as e:
            # Deliberate best-effort: one failing image must not stop the batch.
            print(f"Error processing {ID}: {e}")
            continue
        print(f"Saved registered image to: {output_filename}")
        count += 1

    print(f"Successfully registered {count} images.")
    # Debug information
    print(f"Contents of output directory {output_dir}:")
    print(os.listdir(output_dir))
    return count > 0
def main(temp_img, input_dir, output_dir):
    """
    Register brain MRI images to a template, then skull-strip them with HD-BET.

    Args:
        temp_img {str} -- path to the template (atlas) image
        input_dir {str} -- directory containing the input ``*.nii.gz`` images
        output_dir {str} -- directory for the final results; created if
            missing. Registered intermediates go to its ``temp_registered``
            sub-directory, which is removed when processing ends, including
            when a step fails.

    Returns:
        None
    """
    import shutil

    os.makedirs(output_dir, exist_ok=True)
    # HD-BET device: "0" when CUDA is available, otherwise the CPU.
    device = "0" if torch.cuda.is_available() else "cpu"

    # Scratch directory for intermediate (registered) results.
    temp_reg_dir = os.path.join(output_dir, 'temp_registered')
    os.makedirs(temp_reg_dir, exist_ok=True)
    try:
        print("Starting brain MRI preprocessing...")

        # Registration
        print("\nStep 1: Image Registration")
        success = registration(
            input_dir=input_dir,
            output_dir=temp_reg_dir,
            temp_img=temp_img
        )
        if not success:
            print("Registration failed! No images were processed successfully.")
            return
        print("\nChecking temporary directory contents:")
        print(os.listdir(temp_reg_dir))

        # Skull stripping
        print("\nStep 2: Brain Extraction")
        brain_extraction(
            input_dir=temp_reg_dir,
            output_dir=output_dir,
            device=device
        )
    finally:
        # Always remove the scratch directory. ignore_errors=True so a cleanup
        # problem cannot mask an exception raised by one of the steps above.
        shutil.rmtree(temp_reg_dir, ignore_errors=True)

    print("\nPreprocessing complete! Final results saved in:", output_dir)
    print("Final preprocessed files:")
    print(os.listdir(output_dir))
if __name__ == '__main__':
    # Command-line entry point: all three paths are mandatory string options.
    cli = argparse.ArgumentParser(
        description="Process brain MRI registration and skull stripping.")
    cli_options = (
        ("--temp_img", "Path to the atlas template image."),
        ("--input_dir", "Path to the input images directory."),
        ("--output_dir", "Path to save the processed images."),
    )
    for flag, help_text in cli_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    cli_args = cli.parse_args()
    main(
        temp_img=cli_args.temp_img,
        input_dir=cli_args.input_dir,
        output_dir=cli_args.output_dir,
    )