"""
Femoral region-of-interest (ROI) extraction for hip CT analysis.

Given a CT volume and a femur segmentation, this module places ROIs in the
femoral head, the intertrochanteric region and the femoral neck of each femur
and reports the mean intensity inside each ROI.

@author: louisblankemeier
"""

import math
import os
import shutil

import cv2
import nibabel as nib
import numpy as np
import scipy.ndimage as ndi
from scipy.ndimage import zoom
from skimage.morphology import ball, binary_erosion

from comp2comp.hip.hip_visualization import method_visualizer


def compute_rois(medical_volume, segmentation, model, output_dir, save=False):
    """Compute head, intertrochanter and neck ROIs for both femurs.

    Args:
        medical_volume: Image object exposing ``get_fdata()``, ``affine`` and
            ``header.get_zooms()`` (presumably a nibabel image - confirm
            against the caller).
        segmentation: Label image exposing ``get_fdata()``.
        model: Object whose ``categories`` mapping provides the label values
            for ``"femur_left"`` and ``"femur_right"``.
        output_dir: Directory forwarded to the ROI helpers. When ``save`` is
            true, results are written next to it (in its parent directory).
        save: When true, write the summed ROI mask to
            ``<parent of output_dir>/rois/roi.nii.gz`` and copy
            ``tunnelvision.ipynb`` into the parent directory.

    Returns:
        dict: Keys ``left_head``, ``right_head``, ``left_intertrochanter``,
        ``right_intertrochanter``, ``left_neck`` and ``right_neck``; each value
        is a dict with ``roi`` (uint8 mask), ``centroid`` (voxel indices) and
        ``hu`` (mean intensity inside the ROI).
    """
    # Load the label volume once instead of once per femur.
    segmentation_data = segmentation.get_fdata()
    femur_masks = {
        "left": (segmentation_data == model.categories["femur_left"]).astype(
            np.uint8
        ),
        "right": (segmentation_data == model.categories["femur_right"]).astype(
            np.uint8
        ),
    }

    results = {}

    # Head and intertrochanter ROIs share one routine; `anatomy` selects which
    # of the two superior bone components the ROI is placed in.
    for region in ("head", "intertrochanter"):
        for side in ("left", "right"):
            anatomy = f"{side}_{region}"
            roi, centroid, hu = get_femural_head_roi(
                femur_masks[side], medical_volume, output_dir, anatomy
            )
            results[anatomy] = {"roi": roi, "centroid": centroid, "hu": hu}

    # The neck ROI is defined relative to the two ROIs computed above.
    for side in ("left", "right"):
        intertrochanter = results[f"{side}_intertrochanter"]
        head = results[f"{side}_head"]
        roi, centroid, hu = get_femural_neck_roi(
            femur_masks[side],
            medical_volume,
            intertrochanter["roi"],
            intertrochanter["centroid"],
            head["roi"],
            head["centroid"],
            output_dir,
        )
        results[f"{side}_neck"] = {"roi": roi, "centroid": centroid, "hu": hu}

    if save:
        # All ROIs are summed into a single mask without distinct labels.
        combined_roi = np.zeros_like(results["left_head"]["roi"])
        for entry in results.values():
            combined_roi = combined_roi + entry["roi"]

        parent_output_dir = os.path.dirname(output_dir)
        roi_output_dir = os.path.join(parent_output_dir, "rois")
        os.makedirs(roi_output_dir, exist_ok=True)

        roi_nifti = nib.Nifti1Image(combined_roi, medical_volume.affine)
        nib.save(roi_nifti, os.path.join(roi_output_dir, "roi.nii.gz"))

        shutil.copy(
            os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "tunnelvision.ipynb",
            ),
            parent_output_dir,
        )

    return results


def get_femural_head_roi(
    femur_mask,
    medical_volume,
    output_dir,
    anatomy,
    visualize_method=False,
    min_pixel_count=20,
):
    """Place an ellipsoidal ROI in the femoral head or intertrochanter.

    Starting at the highest occupied slice along axis 2, the function walks
    towards index 0 until a slice contains exactly two connected components of
    at least ``min_pixel_count`` pixels. One component seeds the ROI search,
    which then alternates between inscribing a circle in a sagittal slice and
    in an axial slice to refine the centroid.

    Args:
        femur_mask: 3-D uint8 mask of a single femur. It is not modified.
        medical_volume: Image exposing ``get_fdata()`` and
            ``header.get_zooms()``.
        output_dir: Directory passed to the visualizer.
        anatomy: One of ``"left_head"``, ``"right_head"``,
            ``"left_intertrochanter"`` or ``"right_intertrochanter"``.
        visualize_method: When true, call ``method_visualizer`` with the final
            slices and circles.
        min_pixel_count: Minimum size of a connected component to count.

    Returns:
        tuple: ``(roi, centroid, hu)`` - the uint8 ROI mask restricted to the
        eroded femur, the centroid as a list of three voxel indices, and the
        mean intensity inside the ROI.

    Raises:
        ValueError: If ``anatomy`` is not recognised, the mask is empty, or no
            slice with exactly two sufficiently large components exists.
    """
    if anatomy in ("left_intertrochanter", "right_head"):
        use_low_index_component = True
    elif anatomy in ("right_intertrochanter", "left_head"):
        use_low_index_component = False
    else:
        raise ValueError(f"Unknown anatomy: {anatomy!r}")

    occupied_slices = np.flatnonzero(femur_mask.sum(axis=(0, 1)))
    if occupied_slices.size == 0:
        raise ValueError(f"Femur mask for {anatomy} is empty.")
    top = int(occupied_slices.max())

    print(f"======== Computing {anatomy} femur ROIs ========")

    while top >= 0:
        top_mask = femur_mask[:, :, top]
        labeled, _ = ndi.label(top_mask)
        component_sizes = np.bincount(labeled.ravel())
        valid_components = np.flatnonzero(component_sizes >= min_pixel_count)
        # Label 0 is the background; drop it by value rather than by position.
        valid_components = valid_components[valid_components != 0]
        if len(valid_components) == 2:
            break
        top -= 1
    else:
        # Previously this fell through and failed later on an unbound name.
        raise ValueError("Two connected components not found in the femur mask.")

    center_of_mass_1 = list(
        ndi.center_of_mass(top_mask, labeled, valid_components[0])
    )
    center_of_mass_2 = list(
        ndi.center_of_mass(top_mask, labeled, valid_components[1])
    )

    # "Left" here means the component with the lower index along axis 0.
    if center_of_mass_1[0] < center_of_mass_2[0]:
        left_center_of_mass = center_of_mass_1
        right_center_of_mass = center_of_mass_2
    else:
        left_center_of_mass = center_of_mass_2
        right_center_of_mass = center_of_mass_1

    print(f"Left center of mass: {left_center_of_mass}")
    print(f"Right center of mass: {right_center_of_mass}")

    if use_low_index_component:
        center_of_mass = left_center_of_mass
    else:
        center_of_mass = right_center_of_mass

    # Resample axis 2 to the spacing of axis 1 so that circles inscribed in a
    # sagittal slice are measured on an isotropic grid.
    zooms = medical_volume.header.get_zooms()
    zoom_factor = zooms[2] / zooms[1]

    centroid = [round(center_of_mass[0]), 0, 0]
    print(f"Starting centroid: {centroid}")

    for _ in range(3):
        sagittal_slice = femur_mask[centroid[0], :, :]
        sagittal_slice = zoom(sagittal_slice, (1, zoom_factor), order=1).round()
        centroid[1], centroid[2], radius_sagittal = inscribe_sagittal(
            sagittal_slice, zoom_factor
        )
        print(f"Centroid after inscribe sagittal: {centroid}")

        # Copy the slice: zeroing a view would silently edit the caller's mask
        # and corrupt every later ROI computed from it.
        axial_slice = femur_mask[:, :, centroid[2]].copy()
        if use_low_index_component:
            axial_slice[round(right_center_of_mass[0]) :, :] = 0
        else:
            axial_slice[: round(left_center_of_mass[0]), :] = 0
        centroid[0], centroid[1], radius_axial = inscribe_axial(axial_slice)
        print(f"Centroid after inscribe axial: {centroid}")

    if visualize_method:
        # The resampled images are only needed for the figure.
        volume_data = medical_volume.get_fdata()
        axial_image = volume_data[:, :, round(centroid[2])]
        sagittal_image = volume_data[round(centroid[0]), :, :]
        sagittal_image = zoom(sagittal_image, (1, zoom_factor), order=3).round()
        method_visualizer(
            sagittal_image,
            axial_image,
            axial_slice,
            sagittal_slice,
            [centroid[2], centroid[1]],
            radius_sagittal,
            [centroid[1], centroid[0]],
            radius_axial,
            output_dir,
            anatomy,
        )

    roi = compute_hip_roi(medical_volume, centroid, radius_sagittal, radius_axial)

    # Erode the femur so the ROI keeps a margin from the bone surface.
    femur_mask_eroded = binary_erosion(femur_mask, ball(3))
    roi_eroded = (roi * femur_mask_eroded).astype(np.uint8)

    hu = get_mean_roi_hu(medical_volume, roi_eroded)

    return (roi_eroded, centroid, hu)


def get_femural_neck_roi(
    femur_mask,
    medical_volume,
    intertrochanter_roi,
    intertrochanter_centroid,
    head_roi,
    head_centroid,
    output_dir,
    cylinder_radius=10,
):
    """Place a cylindrical ROI in the femoral neck.

    The cylinder axis runs from the intertrochanter centroid to the head
    centroid. It starts where the intertrochanter ROI ends along that axis and
    stops where the head ROI begins, and is intersected with the eroded femur.

    Args:
        femur_mask: 3-D mask of a single femur.
        medical_volume: Image exposing ``get_fdata()`` and
            ``header.get_zooms()``.
        intertrochanter_roi: ROI mask of the intertrochanteric region.
        intertrochanter_centroid: Centroid of that ROI (three voxel indices).
        head_roi: ROI mask of the femoral head.
        head_centroid: Centroid of the head ROI (three voxel indices).
        output_dir: Unused; kept for interface compatibility.
        cylinder_radius: Cylinder radius, compared against distances computed
            on voxel coordinates scaled by the voxel spacing.

    Returns:
        tuple: ``(neck_roi, centroid, hu)`` - the uint8 ROI mask, the rounded
        midpoint of the cylinder axis, and the mean intensity inside the ROI.

    Raises:
        ValueError: If the two centroids coincide, so no axis can be defined.
    """
    zooms = medical_volume.header.get_zooms()

    origin = np.array(intertrochanter_centroid)
    direction_vector = np.array(head_centroid) - origin
    axis_length = np.linalg.norm(direction_vector)
    if axis_length == 0:
        raise ValueError(
            "Head and intertrochanter centroids coincide; "
            "cannot define the femoral neck axis."
        )
    unit_direction_vector = direction_vector / axis_length

    # Furthest extent of the intertrochanter ROI along the axis (voxel units).
    intertrochanter_points = np.column_stack(np.where(intertrochanter_roi))
    t_start = np.dot(
        intertrochanter_points - intertrochanter_centroid, unit_direction_vector
    ).max()

    # Nearest extent of the head ROI along the axis, measured from the origin.
    head_points = np.column_stack(np.where(head_roi))
    t_end = (
        axis_length
        + np.dot(head_points - head_centroid, unit_direction_vector).min()
    )

    coordinates = np.stack(np.indices(femur_mask.shape), axis=-1)

    # Position along the axis, in voxel units.
    distance_to_line_origin = np.dot(
        coordinates - intertrochanter_centroid, unit_direction_vector
    )

    # Perpendicular distance to the axis, with coordinates scaled by spacing.
    coordinates_zoomed = coordinates * zooms
    intertrochanter_centroid_zoomed = origin * zooms
    unit_direction_vector_zoomed = unit_direction_vector * zooms

    distance_to_line = np.linalg.norm(
        np.cross(
            coordinates_zoomed - intertrochanter_centroid_zoomed,
            coordinates_zoomed
            - (intertrochanter_centroid_zoomed + unit_direction_vector_zoomed),
        ),
        axis=-1,
    ) / np.linalg.norm(unit_direction_vector_zoomed)

    cylinder_mask = (
        (distance_to_line <= cylinder_radius)
        & (distance_to_line_origin >= t_start)
        & (distance_to_line_origin <= t_end)
    )

    # Erode the femur so the ROI keeps a margin from the bone surface.
    femur_mask_eroded = binary_erosion(femur_mask, ball(3))
    neck_roi = (cylinder_mask * femur_mask_eroded).astype(np.uint8)

    hu = get_mean_roi_hu(medical_volume, neck_roi)

    midpoint = origin + unit_direction_vector * (t_start + t_end) / 2
    centroid = [round(value) for value in midpoint]

    return neck_roi, centroid, hu


def compute_hip_roi(img, centroid, radius_sagittal, radius_axial):
    """Rasterise an axis-aligned ellipsoid around ``centroid``.

    The semi-axes are 75% of the inscribed radii divided by the voxel spacing
    of the corresponding axis: ``radius_axial`` is used for axes 0 and 1 and
    ``radius_sagittal`` for axis 2.

    Voxels are evaluated on the same bounding box the original loop visited,
    clipped to the volume. Parts of the ellipsoid outside the volume are
    dropped instead of wrapping around through negative indices or raising
    ``IndexError``.

    Args:
        img: Image exposing ``get_fdata()`` and ``header.get_zooms()``.
        centroid: Ellipsoid centre as three voxel indices.
        radius_sagittal: Radius of the circle inscribed in the sagittal slice.
        radius_axial: Radius of the circle inscribed in the axial slice.

    Returns:
        numpy.ndarray: uint8 mask with the volume's shape, 1 inside the
        ellipsoid and 0 elsewhere.
    """
    pixel_spacing = img.header.get_zooms()
    # Use plain Python floats so the arithmetic is float64 regardless of the
    # NumPy scalar promotion rules in effect.
    semi_axes = (
        float(radius_axial) * 0.75 / float(pixel_spacing[0]),
        float(radius_axial) * 0.75 / float(pixel_spacing[1]),
        float(radius_sagittal) * 0.75 / float(pixel_spacing[2]),
    )

    shape = img.get_fdata().shape
    roi = np.zeros(shape, dtype=np.uint8)

    # A degenerate ellipsoid contains no voxels.
    if min(semi_axes) <= 0:
        return roi

    bounding_box = []
    squared_terms = []
    for axis in range(3):
        center = centroid[axis]
        semi_axis = semi_axes[axis]
        lower = math.floor(center - semi_axis)
        upper = lower + 2 * math.ceil(semi_axis)  # inclusive
        lower = max(lower, 0)
        upper = min(upper, shape[axis] - 1)
        if upper < lower:
            # The bounding box lies entirely outside the volume.
            return roi
        offsets = np.arange(lower, upper + 1) - center
        bounding_box.append(slice(lower, upper + 1))
        squared_terms.append(offsets**2 / semi_axis**2)

    inside = (
        squared_terms[0][:, None, None]
        + squared_terms[1][None, :, None]
        + squared_terms[2][None, None, :]
    ) <= 1
    roi[tuple(bounding_box)] = inside
    return roi


def inscribe_axial(axial_mask):
    """Find the largest circle inscribed in an axial mask slice.

    Args:
        axial_mask: 2-D binary mask indexed as ``[axis 0, axis 1]``.

    Returns:
        tuple: ``(left_right_center, posterior_anterior_center, radius_axial)``
        - the row index, the column index and the radius (in pixels) of the
        point with the largest distance to the background.
    """
    axial_mask = np.ascontiguousarray(axial_mask, dtype=np.uint8)
    dist_map = cv2.distanceTransform(axial_mask, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
    _, radius_axial, _, center_axial = cv2.minMaxLoc(dist_map)
    # OpenCV reports locations as (x, y) = (column, row).
    column, row = center_axial
    left_right_center = round(row)
    posterior_anterior_center = round(column)
    return left_right_center, posterior_anterior_center, radius_axial


def inscribe_sagittal(sagittal_mask, zoom_factor):
    """Find the largest circle inscribed in a resampled sagittal mask slice.

    Args:
        sagittal_mask: 2-D binary mask indexed as ``[axis 1, axis 2]`` whose
            second dimension has been scaled by ``zoom_factor``.
        zoom_factor: Scale that was applied to the second dimension.

    Returns:
        tuple: ``(posterior_anterior_center, inferior_superior_center,
        radius_sagittal)`` - the row index, the column index mapped back to
        the unscaled grid, and the radius in pixels of the resampled slice.
    """
    sagittal_mask = np.ascontiguousarray(sagittal_mask, dtype=np.uint8)
    dist_map = cv2.distanceTransform(
        sagittal_mask, cv2.DIST_L2, cv2.DIST_MASK_PRECISE
    )
    _, radius_sagittal, _, center_sagittal = cv2.minMaxLoc(dist_map)
    # OpenCV reports locations as (x, y) = (column, row).
    column, row = center_sagittal
    posterior_anterior_center = round(row)
    # Undo the resampling to obtain an index into the original volume.
    inferior_superior_center = round(round(column) / zoom_factor)
    return posterior_anterior_center, inferior_superior_center, radius_sagittal


def get_mean_roi_hu(medical_volume, roi):
    """Return the mean intensity of the volume inside ``roi``.

    Voxels are selected by the ROI mask itself, so voxels whose intensity is
    exactly 0 are included in the mean rather than silently dropped.

    Args:
        medical_volume: Image exposing ``get_fdata()``.
        roi: Mask with the same shape as the volume; non-zero marks the ROI.

    Returns:
        float: Mean intensity inside the ROI, or NaN when the ROI is empty.
    """
    roi_values = medical_volume.get_fdata()[np.asarray(roi) != 0]
    if roi_values.size == 0:
        return float("nan")
    return np.mean(roi_values)