""" @author: louisblankemeier """ import math import os import shutil import zipfile from pathlib import Path from time import time from typing import Union import nibabel as nib import numpy as np import pandas as pd import wget from PIL import Image from totalsegmentatorv2.python_api import totalsegmentator from comp2comp.inference_class_base import InferenceClass from comp2comp.io import io_utils from comp2comp.models.models import Models from comp2comp.spine import spine_utils from comp2comp.visualization.dicom import to_dicom # from totalsegmentator.libs import ( # download_pretrained_weights, # nostdout, # setup_nnunet, # ) class SpineSegmentation(InferenceClass): """Spine segmentation.""" def __init__(self, model_name, save=True): super().__init__() self.model_name = model_name self.save_segmentations = save def __call__(self, inference_pipeline): # inference_pipeline.dicom_series_path = self.input_path self.output_dir = inference_pipeline.output_dir self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/") if not os.path.exists(self.output_dir_segmentations): os.makedirs(self.output_dir_segmentations) self.model_dir = inference_pipeline.model_dir # seg, mv = self.spine_seg( # os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), # self.output_dir_segmentations + "spine.nii.gz", # inference_pipeline.model_dir, # ) os.environ["TOTALSEG_WEIGHTS_PATH"] = self.model_dir seg = totalsegmentator( input=os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), output=os.path.join(self.output_dir_segmentations, "segmentation.nii"), task_ids=[292], ml=True, nr_thr_resamp=1, nr_thr_saving=6, fast=False, nora_tag="None", preview=False, task="total", # roi_subset=[ # "vertebrae_T12", # "vertebrae_L1", # "vertebrae_L2", # "vertebrae_L3", # "vertebrae_L4", # "vertebrae_L5", # ], roi_subset=None, statistics=False, radiomics=False, crop_path=None, body_seg=False, force_split=False, output_type="nifti", quiet=False, verbose=False, test=0, skip_saving=True, device="gpu", license_number=None, statistics_exclude_masks_at_border=True, no_derived_masks=False, v1_order=False, ) mv = nib.load( os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz") ) # inference_pipeline.segmentation = nib.load( # os.path.join(self.output_dir_segmentations, "segmentation.nii") # ) inference_pipeline.segmentation = seg inference_pipeline.medical_volume = mv inference_pipeline.save_segmentations = self.save_segmentations return {} def setup_nnunet_c2c(self, model_dir: Union[str, Path]): """Adapted from TotalSegmentator.""" model_dir = Path(model_dir) config_dir = model_dir / Path("." + self.model_name) (config_dir / "nnunet/results/nnUNet/3d_fullres").mkdir( exist_ok=True, parents=True ) (config_dir / "nnunet/results/nnUNet/2d").mkdir(exist_ok=True, parents=True) weights_dir = config_dir / "nnunet/results" self.weights_dir = weights_dir os.environ["nnUNet_raw_data_base"] = str( weights_dir ) # not needed, just needs to be an existing directory os.environ["nnUNet_preprocessed"] = str( weights_dir ) # not needed, just needs to be an existing directory os.environ["RESULTS_FOLDER"] = str(weights_dir) def download_spine_model(self, model_dir: Union[str, Path]): download_dir = Path( os.path.join( self.weights_dir, "nnUNet/3d_fullres/Task252_Spine/nnUNetTrainerV2_ep4000_nomirror__nnUNetPlansv2.1", ) ) fold_0_path = download_dir / "fold_0" if not os.path.exists(fold_0_path): download_dir.mkdir(parents=True, exist_ok=True) wget.download( "https://huggingface.co/louisblankemeier/spine_v1/resolve/main/fold_0.zip", out=os.path.join(download_dir, "fold_0.zip"), ) with zipfile.ZipFile( os.path.join(download_dir, "fold_0.zip"), "r" ) as zip_ref: zip_ref.extractall(download_dir) os.remove(os.path.join(download_dir, "fold_0.zip")) wget.download( "https://huggingface.co/louisblankemeier/spine_v1/resolve/main/plans.pkl", out=os.path.join(download_dir, "plans.pkl"), ) print("Spine model downloaded.") else: print("Spine model already downloaded.") def spine_seg( self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir ): """Run spine segmentation. Args: input_path (Union[str, Path]): Input path. output_path (Union[str, Path]): Output path. """ print("Segmenting spine...") st = time() os.environ["SCRATCH"] = self.model_dir os.environ["TOTALSEG_WEIGHTS_PATH"] = self.model_dir # Setup nnunet model = "3d_fullres" folds = [0] trainer = "nnUNetTrainerV2_ep4000_nomirror" crop_path = None task_id = [252] if self.model_name == "ts_spine": setup_nnunet() download_pretrained_weights(task_id[0]) elif self.model_name == "stanford_spine_v0.0.1": self.setup_nnunet_c2c(model_dir) self.download_spine_model(model_dir) else: raise ValueError("Invalid model name.") if not self.save_segmentations: output_path = None from totalsegmentator.nnunet import nnUNet_predict_image with nostdout(): img, seg = nnUNet_predict_image( input_path, output_path, task_id, model=model, folds=folds, trainer=trainer, tta=False, multilabel_image=True, resample=1.5, crop=None, crop_path=crop_path, task_name="total", nora_tag="None", preview=False, nr_threads_resampling=1, nr_threads_saving=6, quiet=False, verbose=False, test=0, ) end = time() # Log total time for spine segmentation print(f"Total time for spine segmentation: {end-st:.2f}s.") if self.model_name == "stanford_spine_v0.0.1": seg_data = seg.get_fdata() # subtract 17 from seg values except for 0 seg_data = np.where(seg_data == 0, 0, seg_data - 17) seg = nib.Nifti1Image(seg_data, seg.affine, seg.header) return seg, img class AxialCropper(InferenceClass): """Crop the CT image (medical_volume) and segmentation based on user-specified lower and upper levels of the spine. """ def __init__(self, lower_level: str = "L5", upper_level: str = "L1", save=True): """ Args: lower_level (str, optional): Lower level of the spine. Defaults to "L5". upper_level (str, optional): Upper level of the spine. Defaults to "L1". save (bool, optional): Save cropped image and segmentation. Defaults to True. Raises: ValueError: If lower_level or upper_level is not a valid spine level. """ super().__init__() self.lower_level = lower_level self.upper_level = upper_level ts_spine_full_model = Models.model_from_name("ts_spine_full") categories = ts_spine_full_model.categories try: self.lower_level_index = categories[self.lower_level] self.upper_level_index = categories[self.upper_level] except KeyError: raise ValueError("Invalid spine level.") from None self.save = save def __call__(self, inference_pipeline): """ First dim goes from L to R. Second dim goes from P to A. Third dim goes from I to S. """ segmentation = inference_pipeline.segmentation segmentation_data = segmentation.get_fdata() upper_level_index = np.where(segmentation_data == self.upper_level_index)[ 2 ].max() lower_level_index = np.where(segmentation_data == self.lower_level_index)[ 2 ].min() segmentation = segmentation.slicer[:, :, lower_level_index:upper_level_index] inference_pipeline.segmentation = segmentation medical_volume = inference_pipeline.medical_volume medical_volume = medical_volume.slicer[ :, :, lower_level_index:upper_level_index ] inference_pipeline.medical_volume = medical_volume if self.save: nib.save( segmentation, os.path.join( inference_pipeline.output_dir, "segmentations", "spine.nii.gz" ), ) nib.save( medical_volume, os.path.join( inference_pipeline.output_dir, "segmentations", "converted_dcm.nii.gz", ), ) return {} class SpineComputeROIs(InferenceClass): def __init__(self, spine_model): super().__init__() self.spine_model_name = spine_model self.spine_model_type = Models.model_from_name(self.spine_model_name) def __call__(self, inference_pipeline): # Compute ROIs inference_pipeline.spine_model_type = self.spine_model_type (spine_hus, rois, segmentation_hus, centroids_3d) = spine_utils.compute_rois( inference_pipeline.segmentation, inference_pipeline.medical_volume, self.spine_model_type, ) inference_pipeline.spine_hus = spine_hus inference_pipeline.segmentation_hus = segmentation_hus inference_pipeline.rois = rois inference_pipeline.centroids_3d = centroids_3d return {} class SpineMetricsSaver(InferenceClass): """Save metrics to a CSV file.""" def __init__(self): super().__init__() def __call__(self, inference_pipeline): """Save metrics to a CSV file.""" self.spine_hus = inference_pipeline.spine_hus self.seg_hus = inference_pipeline.segmentation_hus self.output_dir = inference_pipeline.output_dir self.csv_output_dir = os.path.join(self.output_dir, "metrics") if not os.path.exists(self.csv_output_dir): os.makedirs(self.csv_output_dir, exist_ok=True) self.save_results() if hasattr(inference_pipeline, "dicom_ds"): if not os.path.exists(os.path.join(self.output_dir, "dicom_metadata.csv")): io_utils.write_dicom_metadata_to_csv( inference_pipeline.dicom_ds, os.path.join(self.output_dir, "dicom_metadata.csv"), ) return {} def save_results(self): """Save results to a CSV file.""" df = pd.DataFrame(columns=["Level", "ROI HU", "Seg HU"]) for i, level in enumerate(self.spine_hus): hu = self.spine_hus[level] seg_hu = self.seg_hus[level] row = [level, hu, seg_hu] df.loc[i] = row df = df.iloc[::-1] df.to_csv(os.path.join(self.csv_output_dir, "spine_metrics.csv"), index=False) class SpineFindDicoms(InferenceClass): def __init__(self): super().__init__() def __call__(self, inference_pipeline): inferior_superior_centers = spine_utils.find_spine_dicoms( inference_pipeline.centroids_3d, ) spine_utils.save_nifti_select_slices( inference_pipeline.output_dir, inferior_superior_centers ) inference_pipeline.dicom_file_paths = [ str(center) for center in inferior_superior_centers ] inference_pipeline.names = list(inference_pipeline.rois.keys()) inference_pipeline.dicom_file_names = list(inference_pipeline.rois.keys()) inference_pipeline.inferior_superior_centers = inferior_superior_centers return {} class SpineCoronalSagittalVisualizer(InferenceClass): def __init__(self, format="png"): super().__init__() self.format = format def __call__(self, inference_pipeline): output_path = inference_pipeline.output_dir spine_model_type = inference_pipeline.spine_model_type img_sagittal, img_coronal = spine_utils.visualize_coronal_sagittal_spine( inference_pipeline.segmentation.get_fdata(), list(inference_pipeline.rois.values()), inference_pipeline.medical_volume.get_fdata(), list(inference_pipeline.centroids_3d.values()), output_path, spine_hus=inference_pipeline.spine_hus, seg_hus=inference_pipeline.segmentation_hus, model_type=spine_model_type, pixel_spacing=inference_pipeline.pixel_spacing_list, format=self.format, ) inference_pipeline.spine_vis_sagittal = img_sagittal inference_pipeline.spine_vis_coronal = img_coronal inference_pipeline.spine = True if not inference_pipeline.save_segmentations: shutil.rmtree(os.path.join(output_path, "segmentations")) return {} class SpineReport(InferenceClass): def __init__(self, format="png"): super().__init__() self.format = format def __call__(self, inference_pipeline): sagittal_image = inference_pipeline.spine_vis_sagittal coronal_image = inference_pipeline.spine_vis_coronal # concatenate these numpy arrays laterally img = np.concatenate((coronal_image, sagittal_image), axis=1) output_path = os.path.join( inference_pipeline.output_dir, "images", "spine_report" ) if self.format == "png": im = Image.fromarray(img) im.save(output_path + ".png") elif self.format == "dcm": to_dicom(img, output_path + ".dcm") return {} class SpineMuscleAdiposeTissueReport(InferenceClass): """Spine muscle adipose tissue report class.""" def __init__(self): super().__init__() self.image_files = [ "spine_coronal.png", "spine_sagittal.png", "T12.png", "L1.png", "L2.png", "L3.png", "L4.png", "L5.png", ] def __call__(self, inference_pipeline): image_dir = Path(inference_pipeline.output_dir) / "images" self.generate_panel(image_dir) return {} def generate_panel(self, image_dir: Union[str, Path]): """Generate panel. Args: image_dir (Union[str, Path]): Path to the image directory. """ image_files = [os.path.join(image_dir, path) for path in self.image_files] # construct a list which includes only the images that exist image_files = [path for path in image_files if os.path.exists(path)] im_cor = Image.open(image_files[0]) im_sag = Image.open(image_files[1]) im_cor_width = int(im_cor.width / im_cor.height * 512) num_muscle_fat_cols = math.ceil((len(image_files) - 2) / 2) width = (8 + im_cor_width + 8) + ((512 + 8) * num_muscle_fat_cols) height = 1048 new_im = Image.new("RGB", (width, height)) index = 2 for j in range(8, height, 520): for i in range(8 + im_cor_width + 8, width, 520): try: im = Image.open(image_files[index]) im.thumbnail((512, 512)) new_im.paste(im, (i, j)) index += 1 im.close() except Exception: continue im_cor.thumbnail((im_cor_width, 512)) new_im.paste(im_cor, (8, 8)) im_sag.thumbnail((im_cor_width, 512)) new_im.paste(im_sag, (8, 528)) new_im.save(os.path.join(image_dir, "spine_muscle_adipose_tissue_report.png")) im_cor.close() im_sag.close() new_im.close()