"""Brain MRI preprocessing: rigid registration to a template, then skull stripping.

Each ``*.nii.gz`` image in the input directory is N4 bias-corrected, rigidly
registered to a template resampled to 1 mm isotropic spacing, and finally
passed through HD-BET for brain extraction.
"""
import argparse
import glob
import os
import random
import shutil
import sys

import SimpleITK as sitk
import torch
from tqdm import tqdm

from HD_BET.hd_bet import hd_bet

# Names accepted for ``interp_type`` and the SimpleITK interpolator each selects.
_INTERPOLATORS = {
    'linear': sitk.sitkLinear,
    'bspline': sitk.sitkBSpline,
    'nearest_neighbor': sitk.sitkNearestNeighbor,
}

_NIFTI_SUFFIX = '.nii.gz'


def brain_extraction(input_dir, output_dir, device):
    """
    Brain extraction using HDBET package (UNet based DL method)
    Args:
        input_dir {path} -- input directory for registered images
        output_dir {path} -- output directory for brain extracted images
        device {str} -- device identifier forwarded to HD-BET
    Returns:
        None -- results are written to output_dir by HD-BET
    """
    print("Running brain extraction...")
    print(f"Input directory: {input_dir}")
    print(f"Output directory: {output_dir}")

    # Run HD-BET directly with the output directory
    hd_bet(input_dir, output_dir, device=device, mode='fast', tta=0)

    print('Brain extraction complete!')
    print("\nContents of output directory after brain extraction:")
    print(os.listdir(output_dir))


def _image_id(path):
    """Return the file name of *path* without its ``.nii.gz`` suffix."""
    name = os.path.basename(path)
    if name.endswith(_NIFTI_SUFFIX):
        # Strip only the suffix so names containing extra dots stay distinct.
        return name[:-len(_NIFTI_SUFFIX)]
    return name.split('.')[0]


def _resolve_interpolator(interp_type):
    """Map an interpolation name to its SimpleITK constant.

    Non-string values are returned unchanged so that callers may pass a
    SimpleITK interpolator constant directly.
    """
    if not isinstance(interp_type, str):
        return interp_type
    try:
        return _INTERPOLATORS[interp_type]
    except KeyError:
        raise ValueError(
            f"Unknown interp_type {interp_type!r}; "
            f"expected one of {sorted(_INTERPOLATORS)}"
        ) from None


def _resample_to_isotropic(image, interpolator, spacing=(1.0, 1.0, 1.0)):
    """Resample *image* onto a grid with the given spacing, keeping its extent."""
    new_size = [
        int(round(size * old / float(new)))
        for size, old, new in zip(image.GetSize(), image.GetSpacing(), spacing)
    ]
    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(spacing)
    resample.SetSize(new_size)
    resample.SetOutputOrigin(image.GetOrigin())
    resample.SetOutputDirection(image.GetDirection())
    resample.SetInterpolator(interpolator)
    # Samples falling outside the source image are filled with zero intensity.
    resample.SetDefaultPixelValue(0.0)
    resample.SetOutputPixelType(sitk.sitkFloat32)
    return resample.Execute(image)


def _rigid_register(fixed_img, moving_img):
    """Estimate the Euler3D transform aligning *moving_img* to *fixed_img*."""
    # Initialize transform
    transform = sitk.CenteredTransformInitializer(
        fixed_img,
        moving_img,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY)

    # Set up registration method
    registration_method = sitk.ImageRegistrationMethod()
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01)
    registration_method.SetInterpolator(sitk.sitkLinear)
    registration_method.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=100,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10)
    registration_method.SetOptimizerScalesFromPhysicalShift()
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
    registration_method.SetInitialTransform(transform)

    # Execute registration
    return registration_method.Execute(fixed_img, moving_img)


def registration(input_dir, output_dir, temp_img, interp_type='linear'):
    """
    MRI registration with SimpleITK
    Args:
        input_dir {path} -- Directory containing input images (*.nii.gz)
        output_dir {path} -- Directory to save registered images
        temp_img {str} -- Registration image template
        interp_type {str} -- Interpolation used when resampling the template:
            'linear', 'bspline' or 'nearest_neighbor' (a SimpleITK
            interpolator constant is also accepted)
    Returns:
        bool -- True if at least one image was registered and saved
    Raises:
        ValueError -- if interp_type is a string that is not recognised
    """
    interpolator = _resolve_interpolator(interp_type)

    # Read the template and resample it to 1mm isotropic once; the result is
    # the same for every moving image, so it does not belong in the loop.
    fixed_img = sitk.ReadImage(temp_img, sitk.sitkFloat32)
    fixed_img = _resample_to_isotropic(fixed_img, interpolator)

    count = 0
    print("Registering images...")
    list_of_files = sorted(glob.glob(os.path.join(input_dir, '*' + _NIFTI_SUFFIX)))

    for img_path in tqdm(list_of_files):
        image_id = _image_id(img_path)

        if "_mask" in image_id:
            continue

        # Unreadable files are reported and skipped so one bad input does not
        # abort the whole batch.
        try:
            moving_img = sitk.ReadImage(img_path, sitk.sitkFloat32)
        except Exception as e:
            print(f"Error loading {image_id}: {e}")
            continue

        print(f"Processing image {count + 1}: {image_id}")

        try:
            moving_img = sitk.N4BiasFieldCorrection(moving_img)
            final_transform = _rigid_register(fixed_img, moving_img)

            # Apply transform and save registered image
            moving_img_resampled = sitk.Resample(
                moving_img,
                fixed_img,
                final_transform,
                sitk.sitkLinear,
                0.0,
                moving_img.GetPixelID())

            # Save with _0000 suffix as required by HD-BET
            output_filename = os.path.join(output_dir, f"{image_id}_0000.nii.gz")
            sitk.WriteImage(moving_img_resampled, output_filename)
            print(f"Saved registered image to: {output_filename}")
        except Exception as e:
            # Best-effort batch processing: report the failure and move on.
            print(f"Error processing {image_id}: {e}")
            continue

        count += 1

    print(f"Successfully registered {count} images.")

    # Debug information
    print(f"Contents of output directory {output_dir}:")
    print(os.listdir(output_dir))

    return count > 0


def main(temp_img, input_dir, output_dir):
    """
    Main function to process brain MRI images
    Args:
        temp_img {str} -- Path to template image
        input_dir {str} -- Path to input directory containing images
        output_dir {str} -- Path to output directory for results
    """
    os.makedirs(output_dir, exist_ok=True)

    # Set device
    device = "0" if torch.cuda.is_available() else "cpu"

    # Create temporary directory for intermediate results
    temp_reg_dir = os.path.join(output_dir, 'temp_registered')
    os.makedirs(temp_reg_dir, exist_ok=True)

    print("Starting brain MRI preprocessing...")

    try:
        # Registration
        print("\nStep 1: Image Registration")
        success = registration(
            input_dir=input_dir,
            output_dir=temp_reg_dir,
            temp_img=temp_img
        )

        if not success:
            print("Registration failed! No images were processed successfully.")
            return

        print("\nChecking temporary directory contents:")
        print(os.listdir(temp_reg_dir))

        # Skull stripping
        print("\nStep 2: Brain Extraction")
        brain_extraction(
            input_dir=temp_reg_dir,
            output_dir=output_dir,
            device=device
        )
    finally:
        # Clean up the temporary directory on every exit path so intermediates
        # from a failed run are not left inside the output directory.
        shutil.rmtree(temp_reg_dir)

    print("\nPreprocessing complete! Final results saved in:", output_dir)
    print("Final preprocessed files:")
    print(os.listdir(output_dir))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Process brain MRI registration and skull stripping.")
    parser.add_argument("--temp_img", type=str, required=True, help="Path to the atlas template image.")
    parser.add_argument("--input_dir", type=str, required=True, help="Path to the input images directory.")
    parser.add_argument("--output_dir", type=str, required=True, help="Path to save the processed images.")

    args = parser.parse_args()

    main(temp_img=args.temp_img, input_dir=args.input_dir, output_dir=args.output_dir)