diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..f3b6cb2af3d1648521052d57b6b792780b34d8c5 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +Comp2Comp-main/figures/aaa_segmentation_video.gif filter=lfs diff=lfs merge=lfs -text +Comp2Comp-main/figures/liver_spleen_pancreas_example.png filter=lfs diff=lfs merge=lfs -text +Comp2Comp-main/figures/spine_muscle_adipose_tissue_example.png filter=lfs diff=lfs merge=lfs -text diff --git a/Comp2Comp-main/.github/workflows/format.yml b/Comp2Comp-main/.github/workflows/format.yml new file mode 100644 index 0000000000000000000000000000000000000000..baa5c86dbb25657f274c51bc2ca02a358b632fad --- /dev/null +++ b/Comp2Comp-main/.github/workflows/format.yml @@ -0,0 +1,33 @@ +name: Autoformat code + +on: + push: + branches: [ 'main' ] + pull_request: + branches: [ 'main' ] + +jobs: + format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Format code + run: | + pip install black + black . + - name: Sort imports + run: | + pip install isort + isort . + - name: Remove unused imports + run: | + pip install autoflake + autoflake --in-place --remove-all-unused-imports --remove-unused-variables --recursive . + - name: Commit changes + uses: EndBug/add-and-commit@v4 + with: + author_name: ${{ github.actor }} + author_email: ${{ github.actor }}@users.noreply.github.com + message: "Autoformat code" + add: "." + branch: ${{ github.ref }} \ No newline at end of file diff --git a/Comp2Comp-main/.gitignore b/Comp2Comp-main/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1ba05edcb57a49622ccca550575bfbf95def830b --- /dev/null +++ b/Comp2Comp-main/.gitignore @@ -0,0 +1,71 @@ +# Ignore project files +**/.idea +**/.DS_Store +**/.vscode + +# Ignore cache +**/__pycache__ + +# Ignore egg files +**/*.egg-info + +# Docs build files +docs/_build + +# Ignore tensorflow logs +**/tf_log + +# Ignore results +**/pik_data +**/preds + +# Ignore test_data +**/test_data +**/testing_data +**/sample_data +**/test_results + +# Ignore images +**/model_imgs + +# Ignore data visualization scripts/images +**/data_visualization +**/OAI-iMorphics + +# temp files +._* +# ignore checkpoint files +**/.ipynb_checkpoints/ +**/.comp2comp/ + +# ignore cross validation files +*.cv + +# ignore yml file +*.yml +*.yaml +!.github/workflows/format.yml + +# ignore images +*.png +!panel_example.png +!logo.png +# except for pngs in the figures folder +!figures/*.png + +# ignore any weights files +weights/ + +# preferences file +comp2comp/preferences.yaml + +# model directory +**/.comp2comp_model_dir/ + +# slurm outputs +**/slurm/ + +# ignore outputs file +**/outputs/ + +**/models/ diff --git a/Comp2Comp-main/Dockerfile b/Comp2Comp-main/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..bab13a2a8ab523f501e3dbb43a1b41039f811caa --- /dev/null +++ b/Comp2Comp-main/Dockerfile @@ -0,0 +1,5 @@ +FROM python:3.8 +COPY . /Comp2Comp +WORKDIR /Comp2Comp +RUN pip install -e . 
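+# ffmpeg, libsm6, and libxext6 are system-level runtime dependencies for OpenCV (cv2)
+# and moviepy video export used by the AAA pipeline (assumption based on the imports in this diff).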
+RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y \ No newline at end of file diff --git a/Comp2Comp-main/LICENSE b/Comp2Comp-main/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/Comp2Comp-main/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/Comp2Comp-main/README.md b/Comp2Comp-main/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..62ae27c3d0d5c23c9e17c6ae0437aac15d72d3ba
--- /dev/null
+++ b/Comp2Comp-main/README.md
@@ -0,0 +1,197 @@
+# Comp2Comp
+[![License: Apache 2.0](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
+![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/StanfordMIMI/Comp2Comp/format.yml?branch=master)
+[![Documentation Status](https://readthedocs.org/projects/comp2comp/badge/?version=latest)](https://comp2comp.readthedocs.io/en/latest/?badge=latest)
+
+[**Paper**](https://arxiv.org/abs/2302.06568)
+| [**Installation**](#installation)
+| [**Basic Usage**](#basic_usage)
+| [**Inference Pipelines**](#basic_usage)
+| [**Contribute**](#contribute)
+| [**Citation**](#citation)
+
+Comp2Comp is a library for extracting clinical insights from computed tomography scans.
+
+## Installation
+
+```bash
+git clone https://github.com/StanfordMIMI/Comp2Comp/
+
+# Install script requires Anaconda/Miniconda.
+cd Comp2Comp && bin/install.sh
+```
+
+Alternatively, Comp2Comp can be installed with `pip`:
+```bash
+git clone https://github.com/StanfordMIMI/Comp2Comp/
+cd Comp2Comp
+conda create -n c2c_env python=3.8
+conda activate c2c_env
+pip install -e .
+```
+
+For installing on the Apple M1 chip, see [these instructions](https://github.com/StanfordMIMI/Comp2Comp/blob/master/docs/Local%20Implementation%20%40%20M1%20arm64%20Silicon.md).
+
+## Basic Usage
+
+```bash
+bin/C2C <pipeline_name> -i <input_path>
+```
+
+For running on slurm, modify the above commands as follows:
+```bash
+bin/C2C-slurm <pipeline_name> -i <input_path>
+```
+
+## Inference Pipelines
+
+We have designed Comp2Comp to be highly extensible and to enable the development of complex clinically-relevant applications. We observed that many clinical applications require chaining several machine learning or other computational modules together to generate complex insights. The inference pipeline system is designed to make this easy. Furthermore, we seek to make the code readable and modular, so that the community can easily contribute to the project.
+
+The [`InferencePipeline` class](comp2comp/inference_pipeline.py) is used to create inference pipelines, which are made up of a sequence of [`InferenceClass` objects](comp2comp/inference_class_base.py). When the `InferencePipeline` object is called, it sequentially calls the `InferenceClasses` that were provided to the constructor.
+
+The first argument of the `__call__` function of `InferenceClass` must be the `InferencePipeline` object. This allows each `InferenceClass` object to access or set attributes of the `InferencePipeline` object that can then be accessed by subsequent `InferenceClass` objects in the pipeline. Each `InferenceClass` object should return a dictionary whose keys match the keyword arguments of the next `InferenceClass`'s `__call__` function. If an `InferenceClass` object only sets attributes of the `InferencePipeline` object and has nothing to pass along, it can return an empty dictionary.
+
+As a minimal sketch of this contract, consider the hypothetical two-stage pipeline below (`InferencePipeline`, `InferenceClass`, and the `medical_volume` attribute set by the I/O stages are names from this repository; the two stage classes themselves are made up):
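+```python
+from comp2comp.inference_class_base import InferenceClass
+from comp2comp.inference_pipeline import InferencePipeline
+
+
+class ComputeThreshold(InferenceClass):
+    """Hypothetical stage: derive a value and pass it downstream."""
+
+    def __call__(self, inference_pipeline):
+        # The first argument is always the pipeline object, which carries shared state.
+        ct = inference_pipeline.medical_volume.get_fdata()
+        # The returned key must match the next stage's keyword argument name.
+        return {"threshold": float(ct.mean())}
+
+
+class ApplyThreshold(InferenceClass):
+    """Hypothetical stage: consume the value returned by the previous stage."""
+
+    def __call__(self, inference_pipeline, threshold):
+        # Stages may also share results by setting attributes on the pipeline.
+        inference_pipeline.mask = (
+            inference_pipeline.medical_volume.get_fdata() > threshold
+        )
+        return {}  # nothing to pass along, so return an empty dictionary
+
+
+pipeline = InferencePipeline([ComputeThreshold(), ApplyThreshold()])
+```
+The concrete pipeline builders in `bin/C2C` (e.g. `SpinePipelineBuilder`) compose their stages in exactly this way.
+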
+Below are the inference pipelines currently supported by Comp2Comp.
+
+## Spine Bone Mineral Density from 3D Trabecular Bone Regions at T12-L5
+
+### Usage
+```bash
+bin/C2C spine -i <input_path>
+```
+- input_path should contain a DICOM series or subfolders that contain DICOM series (one hypothetical layout is shown below).
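+
+For illustration, a hypothetical `input_path` layout (all folder and file names here are made up):
+```
+input_path/
+├── case_001/        # one DICOM series per subfolder
+│   ├── IM-0001.dcm
+│   └── IM-0002.dcm
+└── case_002/
+    └── IM-0001.dcm
+```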

+ +## End-to-End Spine, Muscle, and Adipose Tissue Analysis at T12-L5 + +### Usage +```bash +bin/C2C spine_muscle_adipose_tissue -i +``` +- input_path should contain a DICOM series or subfolders that contain DICOM series. + +### Example Output Image +

+ +

+ +## AAA Segmentation and Maximum Diameter Measurement + +### Usage +```bash +bin/C2C aaa -i +``` +- input_path should contain a DICOM series or subfolders that contain DICOM series. + +### Example Output Image (slice with largest diameter) +

+ +

+ +
+ +| Example Output Video | Example Output Graph | +|-----------------------------|----------------------------| +|

|

| + +
+ +## Contrast Phase Detection + +### Usage +```bash +bin/C2C contrast_phase -i +``` +- input_path should contain a DICOM series or subfolders that contain DICOM series. +- This package has extra dependencies. To install those, run: +```bash +cd Comp2Comp +pip install -e '.[contrast_phase]' +``` + +## 3D Analysis of Liver, Spleen, and Pancreas + +### Usage +```bash +bin/C2C liver_spleen_pancreas -i +``` +- input_path should contain a DICOM series or subfolders that contain DICOM series. + +### Example Output Image +

+ +

+ +## 3D Analysis of the Femur + +### Usage +```bash +bin/C2C hip -i +``` +- input_path should contain a DICOM series or subfolders that contain DICOM series. + +### Example Output Image +

+ +

+ +## Abdominal Aortic Calcification Segmentation + +### Usage +```bash +bin/C2C aortic_calcium -i +``` +- input_path should contain a DICOM series or subfolders that contain DICOM series. + +### Example Output +``` +Statistics on aortic calcifications: +Total number: 7 +Total volume (cm³): 0.348 +Mean HU: 570.3+/-85.8 +Median HU: 544.2+/-85.3 +Max HU: 981.7+/-266.4 +Mean volume (cm³): 0.005+/-0.059 +Median volume (cm³): 0.022 +Max volume (cm³): 0.184 +Min volume (cm³): 0.005 +``` + +## Pipeline that runs all currently implemented pipelines + +### Usage +```bash +bin/C2C all -i +``` +- input_path should contain a DICOM series or subfolders that contain DICOM series. + +## Contribute + +We welcome all pull requests. If you have any issues, suggestions, or feedback, please open a new issue. + +## Citation + +``` +@article{blankemeier2023comp2comp, + title={Comp2Comp: Open-Source Body Composition Assessment on Computed Tomography}, + author={Blankemeier, Louis and Desai, Arjun and Chaves, Juan Manuel Zambrano and Wentland, Andrew and Yao, Sally and Reis, Eduardo and Jensen, Malte and Bahl, Bhanushree and Arora, Khushboo and Patel, Bhavik N and others}, + journal={arXiv preprint arXiv:2302.06568}, + year={2023} +} +``` + +In addition to Comp2Comp, please consider citing TotalSegmentator: +``` +@article{wasserthal2022totalsegmentator, + title={TotalSegmentator: robust segmentation of 104 anatomical structures in CT images}, + author={Wasserthal, Jakob and Meyer, Manfred and Breit, Hanns-Christian and Cyriac, Joshy and Yang, Shan and Segeroth, Martin}, + journal={arXiv preprint arXiv:2208.05868}, + year={2022} +} +``` + + diff --git a/Comp2Comp-main/bin/C2C b/Comp2Comp-main/bin/C2C new file mode 100644 index 0000000000000000000000000000000000000000..d71220eab838c5ec55b171668f33f2fbb35e3b25 --- /dev/null +++ b/Comp2Comp-main/bin/C2C @@ -0,0 +1,276 @@ +#!/usr/bin/env python +import argparse +import os + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" + +from comp2comp.aaa import aaa +from comp2comp.aortic_calcium import ( + aortic_calcium, + aortic_calcium_visualization, +) +from comp2comp.contrast_phase.contrast_phase import ContrastPhaseDetection +from comp2comp.hip import hip +from comp2comp.inference_pipeline import InferencePipeline +from comp2comp.io import io +from comp2comp.liver_spleen_pancreas import ( + liver_spleen_pancreas, + liver_spleen_pancreas_visualization, +) +from comp2comp.muscle_adipose_tissue import ( + muscle_adipose_tissue, + muscle_adipose_tissue_visualization, +) +from comp2comp.spine import spine +from comp2comp.utils import orientation +from comp2comp.utils.process import process_3d + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + +### AAA Pipeline + +def AAAPipelineBuilder(path, args): + pipeline = InferencePipeline( + [ + AxialCropperPipelineBuilder(path, args), + aaa.AortaSegmentation(), + aaa.AortaDiameter(), + aaa.AortaMetricsSaver() + ] + ) + return pipeline + +def MuscleAdiposeTissuePipelineBuilder(args): + pipeline = InferencePipeline( + [ + muscle_adipose_tissue.MuscleAdiposeTissueSegmentation( + 16, args.muscle_fat_model + ), + muscle_adipose_tissue.MuscleAdiposeTissuePostProcessing(), + muscle_adipose_tissue.MuscleAdiposeTissueComputeMetrics(), + muscle_adipose_tissue_visualization.MuscleAdiposeTissueVisualizer(), + muscle_adipose_tissue.MuscleAdiposeTissueH5Saver(), + muscle_adipose_tissue.MuscleAdiposeTissueMetricsSaver(), + ] + ) + return pipeline + + +def MuscleAdiposeTissueFullPipelineBuilder(args): + 
pipeline = InferencePipeline( + [io.DicomFinder(args.input_path), MuscleAdiposeTissuePipelineBuilder(args)] + ) + return pipeline + + +def SpinePipelineBuilder(path, args): + pipeline = InferencePipeline( + [ + io.DicomToNifti(path), + spine.SpineSegmentation(args.spine_model, save=True), + orientation.ToCanonical(), + spine.SpineComputeROIs(args.spine_model), + spine.SpineMetricsSaver(), + spine.SpineCoronalSagittalVisualizer(format="png"), + spine.SpineReport(format="png"), + ] + ) + return pipeline + + +def AxialCropperPipelineBuilder(path, args): + pipeline = InferencePipeline( + [ + io.DicomToNifti(path), + spine.SpineSegmentation(args.spine_model), + orientation.ToCanonical(), + spine.AxialCropper(lower_level="L5", upper_level="L1", save=True), + ] + ) + return pipeline + + +def SpineMuscleAdiposeTissuePipelineBuilder(path, args): + pipeline = InferencePipeline( + [ + SpinePipelineBuilder(path, args), + spine.SpineFindDicoms(), + MuscleAdiposeTissuePipelineBuilder(args), + spine.SpineMuscleAdiposeTissueReport(), + ] + ) + return pipeline + + +def LiverSpleenPancreasPipelineBuilder(path, args): + pipeline = InferencePipeline( + [ + io.DicomToNifti(path), + liver_spleen_pancreas.LiverSpleenPancreasSegmentation(), + orientation.ToCanonical(), + liver_spleen_pancreas_visualization.LiverSpleenPancreasVisualizer(), + liver_spleen_pancreas_visualization.LiverSpleenPancreasMetricsPrinter(), + ] + ) + return pipeline + + +def AorticCalciumPipelineBuilder(path, args): + pipeline = InferencePipeline( + [ + io.DicomToNifti(path), + spine.SpineSegmentation(model_name="ts_spine"), + orientation.ToCanonical(), + aortic_calcium.AortaSegmentation(), + orientation.ToCanonical(), + aortic_calcium.AorticCalciumSegmentation(), + aortic_calcium.AorticCalciumMetrics(), + aortic_calcium_visualization.AorticCalciumVisualizer(), + aortic_calcium_visualization.AorticCalciumPrinter(), + ] + ) + return pipeline + + +def ContrastPhasePipelineBuilder(path, args): + pipeline = InferencePipeline([io.DicomToNifti(path), ContrastPhaseDetection(path)]) + return pipeline + + +def HipPipelineBuilder(path, args): + pipeline = InferencePipeline( + [ + io.DicomToNifti(path), + hip.HipSegmentation(args.hip_model), + orientation.ToCanonical(), + hip.HipComputeROIs(args.hip_model), + hip.HipMetricsSaver(), + hip.HipVisualizer(), + ] + ) + return pipeline + + +def AllPipelineBuilder(path, args): + pipeline = InferencePipeline( + [ + io.DicomToNifti(path), + SpineMuscleAdiposeTissuePipelineBuilder(path, args), + LiverSpleenPancreasPipelineBuilder(path, args), + HipPipelineBuilder(path, args), + ] + ) + return pipeline + + +def argument_parser(): + base_parser = argparse.ArgumentParser(add_help=False) + base_parser.add_argument("--input_path", "-i", type=str, required=True) + base_parser.add_argument("--output_path", "-o", type=str) + base_parser.add_argument("--save_segmentations", action="store_true") + base_parser.add_argument("--overwrite_outputs", action="store_true") + + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="pipeline", help="Pipeline to run") + + # Add the help option to each subparser + muscle_adipose_tissue_parser = subparsers.add_parser( + "muscle_adipose_tissue", parents=[base_parser] + ) + muscle_adipose_tissue_parser.add_argument( + "--muscle_fat_model", default="abCT_v0.0.1", type=str + ) + + # Spine + spine_parser = subparsers.add_parser("spine", parents=[base_parser]) + spine_parser.add_argument("--spine_model", default="ts_spine", type=str) + + # Spine + muscle + fat + 
spine_muscle_adipose_tissue_parser = subparsers.add_parser(
+        "spine_muscle_adipose_tissue", parents=[base_parser]
+    )
+    spine_muscle_adipose_tissue_parser.add_argument(
+        "--muscle_fat_model", default="stanford_v0.0.2", type=str
+    )
+    spine_muscle_adipose_tissue_parser.add_argument(
+        "--spine_model", default="ts_spine", type=str
+    )
+
+    liver_spleen_pancreas = subparsers.add_parser(
+        "liver_spleen_pancreas", parents=[base_parser]
+    )
+
+    aortic_calcium = subparsers.add_parser("aortic_calcium", parents=[base_parser])
+
+    contrast_phase_parser = subparsers.add_parser(
+        "contrast_phase", parents=[base_parser]
+    )
+
+    hip_parser = subparsers.add_parser("hip", parents=[base_parser])
+    hip_parser.add_argument(
+        "--hip_model",
+        default="ts_hip",
+        type=str,
+    )
+
+    # AAA
+    aorta_diameter_parser = subparsers.add_parser("aaa", help="aorta diameter", parents=[base_parser])
+
+    aorta_diameter_parser.add_argument(
+        "--aorta_model",
+        default="ts_spine",
+        type=str,
+        help="aorta model to use for inference",
+    )
+
+    aorta_diameter_parser.add_argument(
+        "--spine_model",
+        default="ts_spine",
+        type=str,
+        help="spine model to use for inference",
+    )
+
+    all_parser = subparsers.add_parser("all", parents=[base_parser])
+    all_parser.add_argument(
+        "--muscle_fat_model",
+        default="abCT_v0.0.1",
+        type=str,
+    )
+    all_parser.add_argument(
+        "--spine_model",
+        default="ts_spine",
+        type=str,
+    )
+    all_parser.add_argument(
+        "--hip_model",
+        default="ts_hip",
+        type=str,
+    )
+    return parser
+
+
+def main():
+    args = argument_parser().parse_args()
+    if args.pipeline == "spine_muscle_adipose_tissue":
+        process_3d(args, SpineMuscleAdiposeTissuePipelineBuilder)
+    elif args.pipeline == "spine":
+        process_3d(args, SpinePipelineBuilder)
+    elif args.pipeline == "contrast_phase":
+        process_3d(args, ContrastPhasePipelineBuilder)
+    elif args.pipeline == "liver_spleen_pancreas":
+        process_3d(args, LiverSpleenPancreasPipelineBuilder)
+    elif args.pipeline == "aortic_calcium":
+        process_3d(args, AorticCalciumPipelineBuilder)
+    elif args.pipeline == "hip":
+        process_3d(args, HipPipelineBuilder)
+    elif args.pipeline == "aaa":
+        process_3d(args, AAAPipelineBuilder)
+    elif args.pipeline == "all":
+        process_3d(args, AllPipelineBuilder)
+    else:
+        # args.pipeline (not args.action) is the dest of the subparsers
+        raise AssertionError("{} command not supported".format(args.pipeline))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/Comp2Comp-main/bin/C2C-slurm b/Comp2Comp-main/bin/C2C-slurm
new file mode 100644
index 0000000000000000000000000000000000000000..9b0daaa9fd4a22326cf57dd45c46c658f033d92a
--- /dev/null
+++ b/Comp2Comp-main/bin/C2C-slurm
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+import os
+import pipes
+import subprocess
+import sys
+from pathlib import Path
+
+exec_file = sys.argv[0].split("-")[0]
+command = exec_file + " " + " ".join([pipes.quote(s) for s in sys.argv[1:]])
+
+
+def submit_command(command):
+    subprocess.run(command.split(" "), check=True, capture_output=False)
+
+
+def python_submit(command, node="siena"):
+    bash_file = open("./slurm.sh", "w")
+    bash_file.write(f"#!/bin/bash\n{command}")
+    bash_file.close()
+    slurm_output_path = Path("./slurm/")
+    slurm_output_path.mkdir(parents=True, exist_ok=True)
+
+    try:
+        if node is None:
+            command = "sbatch --ntasks=1 --cpus-per-task=1 --output ./slurm/slurm-%j.out \
+                --mem-per-cpu=8G -p gpu --gpus 1 --time=1:00:00 slurm.sh"
+            submit_command(command)
+            print(f'Submitted the command --- "{command}" --- to slurm.')
+        else:
+            command = f"sbatch --ntasks=1 --cpus-per-task=1 --output ./slurm/slurm-%j.out \
+                --nodelist={node}
--mem-per-cpu=8G -p gpu --gpus 1 --time=1:00:00 slurm.sh" + submit_command(command) + print(f'Submitted the command --- "{command}" --- to slurm.') + except subprocess.CalledProcessError: + if node == None: + command = f"sbatch -c 8 --gres=gpu:1 --output ./slurm/slurm-%j.out --mem=128000 --time=100-00:00:00 slurm.sh " + submit_command(command) + print(f'Submitted the command --- "{command}" --- to slurm.') + else: + command = f"sbatch -c 8 --gres=gpu:1 --output ./slurm/slurm-%j.out --nodelist={node} --mem=128000 --time=100-00:00:00 slurm.sh" + submit_command(command) + print(f'Submitted the command --- "{command}" --- to slurm.') + os.remove("./slurm.sh") + + +python_submit(command) diff --git a/Comp2Comp-main/bin/install.sh b/Comp2Comp-main/bin/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..9d00a70e27b6d59f9b21a7e008a687122f48418e --- /dev/null +++ b/Comp2Comp-main/bin/install.sh @@ -0,0 +1,146 @@ +#!/bin/bash + +# ============================================================================== +# Auto-installation for abCTSeg for Linux and Mac machines. +# This setup script is adapted from DOSMA: +# https://github.com/ad12/DOSMA +# ============================================================================== + +BIN_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +ANACONDA_KEYWORD="anaconda" +ANACONDA_DOWNLOAD_URL="https://www.anaconda.com/distribution/" +MINICONDA_KEYWORD="miniconda" + +# FIXME: Update the name. +ABCT_ENV_NAME="c2c_env" + +hasAnaconda=0 +updateEnv=0 +updatePath=1 +pythonVersion="3.9" +cudaVersion="" + +while [[ $# -gt 0 ]]; do + key="$1" + case $key in + -h|--help) + echo "Batch evaluation with ss_recon" + echo "" + echo "Usage:" + echo " --python Python version" + echo " -f, --force Force environment update" + exit + ;; + --python) + pythonVersion=$2 + shift # past argument + shift # past value + ;; + --cuda) + cudaVersion=$2 + shift # past argument + shift # past value + ;; + -f|--force) + updateEnv=1 + shift # past argument + ;; + *) + echo "Unknown option: $key" + exit 1 + ;; + esac +done + +# Initial setup +source ~/.bashrc +currDir=`pwd` + + +if echo $PATH | grep -q $ANACONDA_KEYWORD; then + hasAnaconda=1 + echo "Conda found in path" +fi + +if echo $PATH | grep -q $MINICONDA_KEYWORD; then + hasAnaconda=1 + echo "Miniconda found in path" +fi + +if [[ $hasAnaconda -eq 0 ]]; then + echo "Anaconda/Miniconda not installed - install from $ANACONDA_DOWNLOAD_URL" + openURL $ANACONDA_DOWNLOAD_URL + exit 125 +fi + +# Hacky way of finding the conda base directory +condaPath=`which conda` +condaPath=`dirname ${condaPath}` +condaPath=`dirname ${condaPath}` +# Source conda +source $condaPath/etc/profile.d/conda.sh + +# Check if OS is supported +if [[ "$OSTYPE" != "linux-gnu" && "$OSTYPE" != "darwin"* ]]; then + echo "Only Linux and MacOS are supported" + exit 125 +fi + +# Create Anaconda environment (dosma_env) +if [[ `conda env list | grep $ABCT_ENV_NAME` ]]; then + if [[ ${updateEnv} -eq 0 ]]; then + echo "Environment '${ABCT_ENV_NAME}' is installed. Run 'conda activate ${ABCT_ENV_NAME}' to get started." 
+ exit 0 + else + conda env remove -n $ABCT_ENV_NAME + conda create -y -n $ABCT_ENV_NAME python=3.9 + fi +else + conda create -y -n $ABCT_ENV_NAME python=3.9 +fi + +conda activate $ABCT_ENV_NAME + +# Install tensorflow and keras +# https://www.tensorflow.org/install/source#gpu +# pip install tensorflow + +# Install pytorch +# FIXME: PyTorch has to be installed with pip to respect setup.py files from nn UNet +# pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu +# if [[ "$OSTYPE" == "darwin"* ]]; then +# # Mac +# if [[ $cudaVersion != "" ]]; then +# # CPU +# echo "Cannot install PyTorch with CUDA support on Mac" +# exit 1 +# fi +# conda install -y pytorch torchvision torchaudio -c pytorch +# else +# # Linux +# if [[ $cudaVersion == "" ]]; then +# cudatoolkit="cpuonly" +# else +# cudatoolkit="cudatoolkit=${cudaVersion}" +# fi +# conda install -y pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 $cudatoolkit -c pytorch +# fi + +# Install detectron2 +# FIXME: Remove dependency on detectron2 +#pip3 install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.10/index.html + +# Install totalSegmentor +# FIXME: Add this to the setup.py file +# pip3 install git+https://github.com/StanfordMIMI/TotalSegmentator.git + +# cd $currDir/.. +# echo $currDir +# exit 1 + +pip install -e . --no-cache-dir + +echo "" +echo "" +echo "Run 'conda activate ${ABCT_ENV_NAME}' to get started." \ No newline at end of file diff --git a/Comp2Comp-main/comp2comp/__init__.py b/Comp2Comp-main/comp2comp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9218ecb767f1020773fbb8aa1b7c3cdfde21a59c --- /dev/null +++ b/Comp2Comp-main/comp2comp/__init__.py @@ -0,0 +1,8 @@ +from .utils.env import setup_environment + +setup_environment() + + +# This line will be programatically read/write by setup.py. +# Leave them at the bottom of this file and don't touch them. 
+__version__ = "0.0.1" diff --git a/Comp2Comp-main/comp2comp/aaa/aaa.py b/Comp2Comp-main/comp2comp/aaa/aaa.py new file mode 100644 index 0000000000000000000000000000000000000000..fd82d764da23dd2003694937aa08038d2d50f8b4 --- /dev/null +++ b/Comp2Comp-main/comp2comp/aaa/aaa.py @@ -0,0 +1,424 @@ +import math +import operator +import os +import zipfile +from pathlib import Path +from time import time +from tkinter import Tcl +from typing import Union + +import cv2 +import matplotlib.pyplot as plt +import moviepy.video.io.ImageSequenceClip +import nibabel as nib +import numpy as np +import pandas as pd +import pydicom +import wget +from totalsegmentator.libs import nostdout + +from comp2comp.inference_class_base import InferenceClass + + +class AortaSegmentation(InferenceClass): + """Spine segmentation.""" + + def __init__(self, save=True): + super().__init__() + self.model_name = "totalsegmentator" + self.save_segmentations = save + + def __call__(self, inference_pipeline): + # inference_pipeline.dicom_series_path = self.input_path + self.output_dir = inference_pipeline.output_dir + self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/") + if not os.path.exists(self.output_dir_segmentations): + os.makedirs(self.output_dir_segmentations) + + self.model_dir = inference_pipeline.model_dir + + seg, mv = self.spine_seg( + os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), + self.output_dir_segmentations + "spine.nii.gz", + inference_pipeline.model_dir, + ) + + seg = seg.get_fdata() + medical_volume = mv.get_fdata() + + axial_masks = [] + ct_image = [] + + for i in range(seg.shape[2]): + axial_masks.append(seg[:, :, i]) + + for i in range(medical_volume.shape[2]): + ct_image.append(medical_volume[:, :, i]) + + # Save input axial slices to pipeline + inference_pipeline.ct_image = ct_image + + # Save aorta masks to pipeline + inference_pipeline.axial_masks = axial_masks + + return {} + + def setup_nnunet_c2c(self, model_dir: Union[str, Path]): + """Adapted from TotalSegmentator.""" + + model_dir = Path(model_dir) + config_dir = model_dir / Path("." 
+ self.model_name) + (config_dir / "nnunet/results/nnUNet/3d_fullres").mkdir( + exist_ok=True, parents=True + ) + (config_dir / "nnunet/results/nnUNet/2d").mkdir(exist_ok=True, parents=True) + weights_dir = config_dir / "nnunet/results" + self.weights_dir = weights_dir + + os.environ["nnUNet_raw_data_base"] = str( + weights_dir + ) # not needed, just needs to be an existing directory + os.environ["nnUNet_preprocessed"] = str( + weights_dir + ) # not needed, just needs to be an existing directory + os.environ["RESULTS_FOLDER"] = str(weights_dir) + + def download_spine_model(self, model_dir: Union[str, Path]): + download_dir = Path( + os.path.join( + self.weights_dir, + "nnUNet/3d_fullres/Task253_Aorta/nnUNetTrainerV2_ep4000_nomirror__nnUNetPlansv2.1", + ) + ) + print(download_dir) + fold_0_path = download_dir / "fold_0" + if not os.path.exists(fold_0_path): + download_dir.mkdir(parents=True, exist_ok=True) + wget.download( + "https://huggingface.co/AdritRao/aaa_test/resolve/main/fold_0.zip", + out=os.path.join(download_dir, "fold_0.zip"), + ) + with zipfile.ZipFile( + os.path.join(download_dir, "fold_0.zip"), "r" + ) as zip_ref: + zip_ref.extractall(download_dir) + os.remove(os.path.join(download_dir, "fold_0.zip")) + wget.download( + "https://huggingface.co/AdritRao/aaa_test/resolve/main/plans.pkl", + out=os.path.join(download_dir, "plans.pkl"), + ) + print("Spine model downloaded.") + else: + print("Spine model already downloaded.") + + def spine_seg( + self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir + ): + """Run spine segmentation. + + Args: + input_path (Union[str, Path]): Input path. + output_path (Union[str, Path]): Output path. + """ + + print("Segmenting spine...") + st = time() + os.environ["SCRATCH"] = self.model_dir + + print(self.model_dir) + + # Setup nnunet + model = "3d_fullres" + folds = [0] + trainer = "nnUNetTrainerV2_ep4000_nomirror" + crop_path = None + task_id = [253] + + self.setup_nnunet_c2c(model_dir) + self.download_spine_model(model_dir) + + from totalsegmentator.nnunet import nnUNet_predict_image + + with nostdout(): + img, seg = nnUNet_predict_image( + input_path, + output_path, + task_id, + model=model, + folds=folds, + trainer=trainer, + tta=False, + multilabel_image=True, + resample=1.5, + crop=None, + crop_path=crop_path, + task_name="total", + nora_tag="None", + preview=False, + nr_threads_resampling=1, + nr_threads_saving=6, + quiet=False, + verbose=False, + test=0, + ) + end = time() + + # Log total time for spine segmentation + print(f"Total time for spine segmentation: {end-st:.2f}s.") + + seg_data = seg.get_fdata() + seg = nib.Nifti1Image(seg_data, seg.affine, seg.header) + + return seg, img + + +class AortaDiameter(InferenceClass): + def __init__(self): + super().__init__() + + def normalize_img(self, img: np.ndarray) -> np.ndarray: + """Normalize the image. + Args: + img (np.ndarray): Input image. + Returns: + np.ndarray: Normalized image. 
+ """ + return (img - img.min()) / (img.max() - img.min()) + + def __call__(self, inference_pipeline): + axial_masks = ( + inference_pipeline.axial_masks + ) # list of 2D numpy arrays of shape (512, 512) + ct_img = ( + inference_pipeline.ct_image + ) # 3D numpy array of shape (512, 512, num_axial_slices) + + # image output directory + output_dir = inference_pipeline.output_dir + output_dir_slices = os.path.join(output_dir, "images/slices/") + if not os.path.exists(output_dir_slices): + os.makedirs(output_dir_slices) + + output_dir = inference_pipeline.output_dir + output_dir_summary = os.path.join(output_dir, "images/summary/") + if not os.path.exists(output_dir_summary): + os.makedirs(output_dir_summary) + + DICOM_PATH = inference_pipeline.dicom_series_path + dicom = pydicom.dcmread(DICOM_PATH + "/" + os.listdir(DICOM_PATH)[0]) + + dicom.PhotometricInterpretation = "YBR_FULL" + pixel_conversion = dicom.PixelSpacing + print("Pixel conversion: " + str(pixel_conversion)) + RATIO_PIXEL_TO_MM = pixel_conversion[0] + + SLICE_COUNT = dicom["InstanceNumber"].value + print(SLICE_COUNT) + + SLICE_COUNT = len(ct_img) + diameterDict = {} + + for i in range(len(ct_img)): + mask = axial_masks[i].astype("uint8") + + img = ct_img[i] + + img = np.clip(img, -300, 1800) + img = self.normalize_img(img) * 255.0 + img = img.reshape((img.shape[0], img.shape[1], 1)) + img = np.tile(img, (1, 1, 3)) + + contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) + + if len(contours) != 0: + areas = [cv2.contourArea(c) for c in contours] + sorted_areas = np.sort(areas) + + areas = [cv2.contourArea(c) for c in contours] + sorted_areas = np.sort(areas) + contours = contours[areas.index(sorted_areas[-1])] + + img.copy() + + back = img.copy() + cv2.drawContours(back, [contours], 0, (0, 255, 0), -1) + + alpha = 0.25 + img = cv2.addWeighted(img, 1 - alpha, back, alpha, 0) + + ellipse = cv2.fitEllipse(contours) + (xc, yc), (d1, d2), angle = ellipse + + cv2.ellipse(img, ellipse, (0, 255, 0), 1) + + xc, yc = ellipse[0] + cv2.circle(img, (int(xc), int(yc)), 5, (0, 0, 255), -1) + + rmajor = max(d1, d2) / 2 + rminor = min(d1, d2) / 2 + + ### Draw major axes + + if angle > 90: + angle = angle - 90 + else: + angle = angle + 90 + print(angle) + xtop = xc + math.cos(math.radians(angle)) * rmajor + ytop = yc + math.sin(math.radians(angle)) * rmajor + xbot = xc + math.cos(math.radians(angle + 180)) * rmajor + ybot = yc + math.sin(math.radians(angle + 180)) * rmajor + cv2.line( + img, (int(xtop), int(ytop)), (int(xbot), int(ybot)), (0, 0, 255), 3 + ) + + ### Draw minor axes + + if angle > 90: + angle = angle - 90 + else: + angle = angle + 90 + print(angle) + x1 = xc + math.cos(math.radians(angle)) * rminor + y1 = yc + math.sin(math.radians(angle)) * rminor + x2 = xc + math.cos(math.radians(angle + 180)) * rminor + y2 = yc + math.sin(math.radians(angle + 180)) * rminor + cv2.line(img, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 3) + + # pixel_length = math.sqrt( (x1-x2)**2 + (y1-y2)**2 ) + pixel_length = rminor * 2 + + print("Pixel_length_minor: " + str(pixel_length)) + + area_px = cv2.contourArea(contours) + area_mm = round(area_px * RATIO_PIXEL_TO_MM) + area_cm = area_mm / 10 + + diameter_mm = round((pixel_length) * RATIO_PIXEL_TO_MM) + diameter_cm = diameter_mm / 10 + + diameterDict[(SLICE_COUNT - (i))] = diameter_cm + + img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE) + + h, w, c = img.shape + lbls = [ + "Area (mm): " + str(area_mm) + "mm", + "Area (cm): " + str(area_cm) + "cm", + "Diameter (mm): " + 
str(diameter_mm) + "mm", + "Diameter (cm): " + str(diameter_cm) + "cm", + "Slice: " + str(SLICE_COUNT - (i)), + ] + font = cv2.FONT_HERSHEY_SIMPLEX + + scale = 0.03 + fontScale = min(w, h) / (25 / scale) + + cv2.putText(img, lbls[0], (10, 40), font, fontScale, (0, 255, 0), 2) + + cv2.putText(img, lbls[1], (10, 70), font, fontScale, (0, 255, 0), 2) + + cv2.putText(img, lbls[2], (10, 100), font, fontScale, (0, 255, 0), 2) + + cv2.putText(img, lbls[3], (10, 130), font, fontScale, (0, 255, 0), 2) + + cv2.putText(img, lbls[4], (10, 160), font, fontScale, (0, 255, 0), 2) + + cv2.imwrite( + output_dir_slices + "slice" + str(SLICE_COUNT - (i)) + ".png", img + ) + + plt.bar(list(diameterDict.keys()), diameterDict.values(), color="b") + + plt.title(r"$\bf{Diameter}$" + " " + r"$\bf{Progression}$") + + plt.xlabel("Slice Number") + + plt.ylabel("Diameter Measurement (cm)") + plt.savefig(output_dir_summary + "diameter_graph.png", dpi=500) + + print(diameterDict) + print(max(diameterDict.items(), key=operator.itemgetter(1))[0]) + print(diameterDict[max(diameterDict.items(), key=operator.itemgetter(1))[0]]) + + inference_pipeline.max_diameter = diameterDict[ + max(diameterDict.items(), key=operator.itemgetter(1))[0] + ] + + img = ct_img[ + SLICE_COUNT - (max(diameterDict.items(), key=operator.itemgetter(1))[0]) + ] + img = np.clip(img, -300, 1800) + img = self.normalize_img(img) * 255.0 + img = img.reshape((img.shape[0], img.shape[1], 1)) + img2 = np.tile(img, (1, 1, 3)) + img2 = cv2.rotate(img2, cv2.ROTATE_90_COUNTERCLOCKWISE) + + img1 = cv2.imread( + output_dir_slices + + "slice" + + str(max(diameterDict.items(), key=operator.itemgetter(1))[0]) + + ".png" + ) + + border_size = 3 + img1 = cv2.copyMakeBorder( + img1, + top=border_size, + bottom=border_size, + left=border_size, + right=border_size, + borderType=cv2.BORDER_CONSTANT, + value=[0, 244, 0], + ) + img2 = cv2.copyMakeBorder( + img2, + top=border_size, + bottom=border_size, + left=border_size, + right=border_size, + borderType=cv2.BORDER_CONSTANT, + value=[244, 0, 0], + ) + + vis = np.concatenate((img2, img1), axis=1) + cv2.imwrite(output_dir_summary + "out.png", vis) + + image_folder = output_dir_slices + fps = 20 + image_files = [ + os.path.join(image_folder, img) + for img in Tcl().call("lsort", "-dict", os.listdir(image_folder)) + if img.endswith(".png") + ] + clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip( + image_files, fps=fps + ) + clip.write_videofile(output_dir_summary + "aaa.mp4") + + return {} + + +class AortaMetricsSaver(InferenceClass): + """Save metrics to a CSV file.""" + + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline): + """Save metrics to a CSV file.""" + self.max_diameter = inference_pipeline.max_diameter + self.dicom_series_path = inference_pipeline.dicom_series_path + self.output_dir = inference_pipeline.output_dir + self.csv_output_dir = os.path.join(self.output_dir, "metrics") + if not os.path.exists(self.csv_output_dir): + os.makedirs(self.csv_output_dir, exist_ok=True) + self.save_results() + return {} + + def save_results(self): + """Save results to a CSV file.""" + _, filename = os.path.split(self.dicom_series_path) + data = [[filename, str(self.max_diameter)]] + df = pd.DataFrame(data, columns=["Filename", "Max Diameter"]) + df.to_csv(os.path.join(self.csv_output_dir, "aorta_metrics.csv"), index=False) diff --git a/Comp2Comp-main/comp2comp/aortic_calcium/aortic_calcium.py b/Comp2Comp-main/comp2comp/aortic_calcium/aortic_calcium.py new file mode 100644 index 
0000000000000000000000000000000000000000..0e9df732b6ffd4dc1052d254a5f8b125af005aae --- /dev/null +++ b/Comp2Comp-main/comp2comp/aortic_calcium/aortic_calcium.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Thu Apr 20 20:36:05 2023 + +@author: maltejensen +""" +import os +import time +from pathlib import Path +from typing import Union + +import numpy as np +from scipy import ndimage +from totalsegmentator.libs import ( + download_pretrained_weights, + nostdout, + setup_nnunet, +) + +from comp2comp.inference_class_base import InferenceClass + + +class AortaSegmentation(InferenceClass): + """Aorta segmentation.""" + + def __init__(self): + super().__init__() + # self.input_path = input_path + + def __call__(self, inference_pipeline): + # inference_pipeline.dicom_series_path = self.input_path + self.output_dir = inference_pipeline.output_dir + self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/") + inference_pipeline.output_dir_segmentations = os.path.join( + self.output_dir, "segmentations/" + ) + + if not os.path.exists(self.output_dir_segmentations): + os.makedirs(self.output_dir_segmentations) + + self.model_dir = inference_pipeline.model_dir + + mv, seg = self.aorta_seg( + os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), + self.output_dir_segmentations + "organs.nii.gz", + inference_pipeline.model_dir, + ) + # the medical volume is already set by the spine segmentation model + # the toCanonical methods looks for "segmentation", so it's overridden + inference_pipeline.spine_segmentation = inference_pipeline.segmentation + inference_pipeline.segmentation = seg + + return {} + + def aorta_seg( + self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir + ): + """Run organ segmentation. + + Args: + input_path (Union[str, Path]): Input path. + output_path (Union[str, Path]): Output path. 
+        """
+
+        print("Segmenting aorta...")
+        st = time.time()
+        os.environ["SCRATCH"] = self.model_dir
+
+        # Setup nnunet
+        model = "3d_fullres"
+        folds = [0]
+        trainer = "nnUNetTrainerV2_ep4000_nomirror"
+        crop_path = None
+        task_id = [251]
+
+        setup_nnunet()
+        download_pretrained_weights(task_id[0])
+
+        from totalsegmentator.nnunet import nnUNet_predict_image
+
+        with nostdout():
+            seg, mvs = nnUNet_predict_image(
+                input_path,
+                output_path,
+                task_id,
+                model=model,
+                folds=folds,
+                trainer=trainer,
+                tta=False,
+                multilabel_image=True,
+                resample=1.5,
+                crop=None,
+                crop_path=crop_path,
+                task_name="total",
+                nora_tag="None",
+                preview=False,
+                nr_threads_resampling=1,
+                nr_threads_saving=6,
+                quiet=False,
+                verbose=True,
+                test=0,
+            )
+        end = time.time()
+
+        # Log total time for aorta segmentation
+        print(f"Total time for aorta segmentation: {end-st:.2f}s.")
+
+        return seg, mvs
+
+
+class AorticCalciumSegmentation(InferenceClass):
+    """Segmentation of aortic calcium"""
+
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, inference_pipeline):
+        ct = inference_pipeline.medical_volume.get_fdata()
+        aorta_mask = inference_pipeline.segmentation.get_fdata() == 7
+        spine_mask = inference_pipeline.spine_segmentation.get_fdata() > 0
+
+        inference_pipeline.calc_mask = self.detectCalcifications(
+            ct, aorta_mask, exclude_mask=spine_mask, remove_size=3
+        )
+
+        self.output_dir = inference_pipeline.output_dir
+        self.output_dir_images_organs = os.path.join(self.output_dir, "images/")
+        inference_pipeline.output_dir_images_organs = self.output_dir_images_organs
+
+        if not os.path.exists(self.output_dir_images_organs):
+            os.makedirs(self.output_dir_images_organs)
+
+        # np.save(os.path.join(self.output_dir_images_organs, 'ct.npy'), ct)
+        # np.save(os.path.join(self.output_dir_images_organs, "aorta_mask.npy"), aorta_mask)
+        # np.save(os.path.join(self.output_dir_images_organs, "spine_mask.npy"), spine_mask)
+
+        # np.save(
+        #     os.path.join(self.output_dir_images_organs, "calcium_mask.npy"),
+        #     inference_pipeline.calc_mask,
+        # )
+        # np.save(
+        #     os.path.join(self.output_dir_images_organs, "ct_scan.npy"),
+        #     inference_pipeline.medical_volume.get_fdata(),
+        # )
+
+        return {}
+
+    def detectCalcifications(
+        self,
+        ct,
+        aorta_mask,
+        exclude_mask=None,
+        return_dilated_mask=False,
+        dilation=(3, 1),
+        dilation_iteration=4,
+        return_dilated_exclude=False,
+        dilation_exclude_mask=(3, 1),
+        dilation_iteration_exclude=3,
+        show_time=False,
+        num_std=3,
+        remove_size=None,
+        verbose=False,
+        exclude_center_aorta=True,
+        return_eroded_aorta=False,
+        aorta_erode_iteration=6,
+    ):
+        """
+        Function that takes in a CT image and aorta segmentation (and optionally volumes to use
+        for exclusion of the segmentations), and returns a mask of the segmented calcifications
+        (and optionally other volumes). The calcium threshold is adaptive and uses the median
+        of the CT points inside the aorta together with one standard deviation to the left, as
+        this is more stable. num_std is multiplied with the distance between the median and
+        the one-standard-deviation mark to control the threshold, i.e.
+        threshold = median + num_std * (median - Q(15.8%)).
+
+        Args:
+            ct (array): CT image.
+            aorta_mask (array): Mask of the aorta.
+            exclude_mask (array, optional):
+                Mask for structures to exclude, e.g. spine. Defaults to None.
+            return_dilated_mask (bool, optional):
+                Return the dilated aorta mask. Defaults to False.
+            dilation (list, optional):
+                Structuring element for aorta dilation. Defaults to (3,1).
+            dilation_iteration (int, optional):
+                Number of iterations for the structuring element. Defaults to 4.
+            return_dilated_exclude (bool, optional):
+                Return the dilated exclusion mask. Defaults to False.
+            dilation_exclude_mask (list, optional):
+                Structuring element for the exclusion mask. Defaults to (3,1).
+            dilation_iteration_exclude (int, optional):
+                Number of iterations for the structuring element. Defaults to 3.
+            show_time (bool, optional):
+                Show time for each operation. Defaults to False.
+            num_std (float, optional):
+                How many standard deviations out the threshold will be set at. Defaults to 3.
+            remove_size (int, optional):
+                Remove foci under a certain size. Warning: quite slow. Defaults to None.
+            verbose (bool, optional):
+                Give verbose feedback on operations. Defaults to False.
+            exclude_center_aorta (bool, optional):
+                Use eroded aorta to exclude center of the aorta. Defaults to True.
+            return_eroded_aorta (bool, optional):
+                Return the eroded center aorta. Defaults to False.
+            aorta_erode_iteration (int, optional):
+                Number of iterations for the structuring element. Defaults to 6.
+
+        Returns:
+            results: array if only the mask is returned, or dict if other volumes are also returned.
+
+        """
+
+        def slicedDilationOrErosion(input_mask, struct, num_iteration, operation):
+            """
+            Perform the dilation on the smallest slice that will fit the
+            segmentation
+            """
+            margin = 2 if num_iteration is None else num_iteration + 1
+
+            x_idx = np.where(input_mask.sum(axis=(1, 2)))[0]
+            x_start, x_end = x_idx[0] - margin, x_idx[-1] + margin
+            y_idx = np.where(input_mask.sum(axis=(0, 2)))[0]
+            y_start, y_end = y_idx[0] - margin, y_idx[-1] + margin
+
+            if operation == "dilate":
+                mask_slice = ndimage.binary_dilation(
+                    input_mask[x_start:x_end, y_start:y_end, :], structure=struct
+                ).astype(np.int8)
+            elif operation == "erode":
+                mask_slice = ndimage.binary_erosion(
+                    input_mask[x_start:x_end, y_start:y_end, :], structure=struct
+                ).astype(np.int8)
+
+            output_mask = input_mask.copy()
+
+            output_mask[x_start:x_end, y_start:y_end, :] = mask_slice
+
+            return output_mask
+
+        # remove parts that are not the abdominal aorta
+        labelled_aorta, num_classes = ndimage.label(aorta_mask)
+        if num_classes > 1:
+            if verbose:
+                print("Removing {} parts".format(num_classes - 1))
+
+            aorta_vols = []
+
+            for i in range(1, num_classes + 1):
+                aorta_vols.append((labelled_aorta == i).sum())
+
+            biggest_idx = np.argmax(aorta_vols) + 1
+            aorta_mask[labelled_aorta != biggest_idx] = 0
+
+        # Get aortic CT point to set adaptive threshold
+        aorta_ct_points = ct[aorta_mask == 1]
+
+        # equal to one standard deviation to the left of the curve
+        quant = 0.158
+        quantile_median_dist = np.median(aorta_ct_points) - np.quantile(
+            aorta_ct_points, q=quant
+        )
+        calc_thres = np.median(aorta_ct_points) + quantile_median_dist * num_std
+
+        t0 = time.time()
+
+        if dilation is not None:
+            struct = ndimage.generate_binary_structure(*dilation)
+            if dilation_iteration is not None:
+                struct = ndimage.iterate_structure(struct, dilation_iteration)
+            aorta_dilated = slicedDilationOrErosion(
+                aorta_mask,
+                struct=struct,
+                num_iteration=dilation_iteration,
+                operation="dilate",
+            )
+
+            if show_time:
+                print("dilation mask time: {:.2f}".format(time.time() - t0))
+
+        t0 = time.time()
+        calc_mask = np.logical_and(aorta_dilated == 1, ct >= calc_thres)
+        if show_time:
+            print("find calc time: {:.2f}".format(time.time() - t0))
+
+        if exclude_center_aorta:
+            t0 = time.time()
+
+            struct = ndimage.generate_binary_structure(3, 1)
+            struct = 
ndimage.iterate_structure(struct, aorta_erode_iteration) + + aorta_eroded = slicedDilationOrErosion( + aorta_mask, + struct=struct, + num_iteration=aorta_erode_iteration, + operation="erode", + ) + calc_mask = calc_mask * (aorta_eroded == 0) + if show_time: + print("exclude center aorta time: {:.2f} sec".format(time.time() - t0)) + + t0 = time.time() + if exclude_mask is not None: + if dilation_exclude_mask is not None: + struct_exclude = ndimage.generate_binary_structure( + *dilation_exclude_mask + ) + if dilation_iteration_exclude is not None: + struct_exclude = ndimage.iterate_structure( + struct_exclude, dilation_iteration_exclude + ) + + exclude_mask = slicedDilationOrErosion( + exclude_mask, + struct=struct_exclude, + num_iteration=dilation_iteration_exclude, + operation="dilate", + ) + + if show_time: + print("exclude dilation time: {:.2f}".format(time.time() - t0)) + + t0 = time.time() + calc_mask = calc_mask * (exclude_mask == 0) + if show_time: + print("exclude time: {:.2f}".format(time.time() - t0)) + + if remove_size is not None: + t0 = time.time() + + labels, num_features = ndimage.label(calc_mask) + + counter = 0 + for n in range(1, num_features + 1): + idx_tmp = labels == n + if idx_tmp.sum() <= remove_size: + calc_mask[idx_tmp] = 0 + counter += 1 + + if show_time: + print("Size exclusion time: {:.1f} sec".format(time.time() - t0)) + if verbose: + print("Excluded {} foci under {}".format(counter, remove_size)) + + if not all([return_dilated_mask, return_dilated_exclude]): + return calc_mask.astype(np.int8) + else: + results = {} + results["calc_mask"] = calc_mask.astype(np.int8) + if return_dilated_mask: + results["dilated_mask"] = aorta_dilated + if return_dilated_exclude: + results["dilated_exclude"] = exclude_mask + if return_eroded_aorta: + results["aorta_eroded"] = aorta_eroded + + results["threshold"] = calc_thres + + return results + + +class AorticCalciumMetrics(InferenceClass): + """Calculate metrics for the aortic calcifications""" + + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline): + calc_mask = inference_pipeline.calc_mask + + inference_pipeline.pix_dims = inference_pipeline.medical_volume.header[ + "pixdim" + ][1:4] + # divided with 10 to get in cm + inference_pipeline.vol_per_pixel = np.prod(inference_pipeline.pix_dims / 10) + + # count statistics for individual calcifications + labelled_calc, num_lesions = ndimage.label(calc_mask) + + metrics = { + "volume": [], + "mean_hu": [], + "median_hu": [], + "max_hu": [], + } + + ct = inference_pipeline.medical_volume.get_fdata() + + for i in range(1, num_lesions + 1): + tmp_mask = labelled_calc == i + + tmp_ct_vals = ct[tmp_mask] + + metrics["volume"].append( + len(tmp_ct_vals) * inference_pipeline.vol_per_pixel + ) + metrics["mean_hu"].append(np.mean(tmp_ct_vals)) + metrics["median_hu"].append(np.median(tmp_ct_vals)) + metrics["max_hu"].append(np.max(tmp_ct_vals)) + + # Volume of calcificaitons + calc_vol = np.sum(metrics["volume"]) + metrics["volume_total"] = calc_vol + + metrics["num_calc"] = len(metrics["volume"]) + + inference_pipeline.metrics = metrics + + return {} diff --git a/Comp2Comp-main/comp2comp/aortic_calcium/aortic_calcium_visualization.py b/Comp2Comp-main/comp2comp/aortic_calcium/aortic_calcium_visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..ef551475ccff769806793a0e4a3aa7a650df1137 --- /dev/null +++ b/Comp2Comp-main/comp2comp/aortic_calcium/aortic_calcium_visualization.py @@ -0,0 +1,119 @@ +import os + +import numpy as np + 
+from comp2comp.inference_class_base import InferenceClass + + +class AorticCalciumVisualizer(InferenceClass): + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline): + self.output_dir = inference_pipeline.output_dir + self.output_dir_images_organs = os.path.join(self.output_dir, "images/") + inference_pipeline.output_dir_images_organs = self.output_dir_images_organs + + if not os.path.exists(self.output_dir_images_organs): + os.makedirs(self.output_dir_images_organs) + + return {} + + +class AorticCalciumPrinter(InferenceClass): + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline): + metrics = inference_pipeline.metrics + + inference_pipeline.csv_output_dir = os.path.join( + inference_pipeline.output_dir, "metrics" + ) + os.makedirs(inference_pipeline.csv_output_dir, exist_ok=True) + + with open( + os.path.join(inference_pipeline.csv_output_dir, "aortic_calcification.csv"), + "w", + ) as f: + f.write("Volume (cm^3),Mean HU,Median HU,Max HU\n") + for vol, mean, median, max in zip( + metrics["volume"], + metrics["mean_hu"], + metrics["median_hu"], + metrics["max_hu"], + ): + f.write("{},{:.1f},{:.1f},{:.1f}\n".format(vol, mean, median, max)) + + with open( + os.path.join( + inference_pipeline.csv_output_dir, "aortic_calcification_total.csv" + ), + "w", + ) as f: + f.write("Total number,{}\n".format(metrics["num_calc"])) + f.write("Total volume (cm^3),{}\n".format(metrics["volume_total"])) + + distance = 25 + print("\n") + if metrics["num_calc"] == 0: + print("No aortic calcifications were found.") + else: + print("Statistics on aortic calcifications:") + print("{:<{}}{}".format("Total number:", distance, metrics["num_calc"])) + print( + "{:<{}}{:.3f}".format( + "Total volume (cm³):", distance, metrics["volume_total"] + ) + ) + print( + "{:<{}}{:.1f}+/-{:.1f}".format( + "Mean HU:", + distance, + np.mean(metrics["mean_hu"]), + np.std(metrics["mean_hu"]), + ) + ) + print( + "{:<{}}{:.1f}+/-{:.1f}".format( + "Median HU:", + distance, + np.mean(metrics["median_hu"]), + np.std(metrics["median_hu"]), + ) + ) + print( + "{:<{}}{:.1f}+/-{:.1f}".format( + "Max HU:", + distance, + np.mean(metrics["max_hu"]), + np.std(metrics["max_hu"]), + ) + ) + print( + "{:<{}}{:.3f}+/-{:.3f}".format( + "Mean volume (cm³):", + distance, + np.mean(metrics["volume"]), + np.std(metrics["volume"]), + ) + ) + print( + "{:<{}}{:.3f}".format( + "Median volume (cm³):", distance, np.median(metrics["volume"]) + ) + ) + print( + "{:<{}}{:.3f}".format( + "Max volume (cm³):", distance, np.max(metrics["volume"]) + ) + ) + print( + "{:<{}}{:.3f}".format( + "Min volume (cm³):", distance, np.min(metrics["volume"]) + ) + ) + + print("\n") + + return {} diff --git a/Comp2Comp-main/comp2comp/contrast_phase/contrast_inf.py b/Comp2Comp-main/comp2comp/contrast_phase/contrast_inf.py new file mode 100644 index 0000000000000000000000000000000000000000..4676d24fcc47038a48993012ec4f47265f64598e --- /dev/null +++ b/Comp2Comp-main/comp2comp/contrast_phase/contrast_inf.py @@ -0,0 +1,466 @@ +import argparse +import os +import pickle +import sys + +import nibabel as nib +import numpy as np +import scipy +import SimpleITK as sitk +from scipy import ndimage as ndi + + +def loadNiiToArray(path): + NiImg = nib.load(path) + array = np.array(NiImg.dataobj) + return array + + +def loadNiiWithSitk(path): + reader = sitk.ImageFileReader() + reader.SetImageIO("NiftiImageIO") + reader.SetFileName(path) + image = reader.Execute() + array = sitk.GetArrayFromImage(image) + return array + + +def 
loadNiiImageWithSitk(path):
+    reader = sitk.ImageFileReader()
+    reader.SetImageIO("NiftiImageIO")
+    reader.SetFileName(path)
+    image = reader.Execute()
+    # flip the image to be compatible with Nibabel
+    image = sitk.Flip(image, [False, True, False])
+    return image
+
+
+def keep_masked_values(arr, mask):
+    # Get the indices of the non-zero elements in the mask
+    mask_indices = np.nonzero(mask)
+    # Use the indices to select the corresponding elements from the array
+    masked_values = arr[mask_indices]
+    # Return the selected elements as a new array
+    return masked_values
+
+
+def get_stats(arr):
+    # # Get the indices of the non-zero elements in the array
+    # nonzero_indices = np.nonzero(arr)
+    # # Use the indices to get the non-zero elements of the array
+    # nonzero_elements = arr[nonzero_indices]
+
+    nonzero_elements = arr
+
+    # Calculate the stats for the non-zero elements
+    max_val = np.max(nonzero_elements)
+    min_val = np.min(nonzero_elements)
+    mean_val = np.mean(nonzero_elements)
+    median_val = np.median(nonzero_elements)
+    std_val = np.std(nonzero_elements)
+    variance_val = np.var(nonzero_elements)
+    return max_val, min_val, mean_val, median_val, std_val, variance_val
+
+
+def getMaskAnteriorAtrium(mask):
+    # For each slice, mark everything anterior to the atrium's first row
+    erasePreAtriumMask = mask.copy()
+    for sliceNum in range(mask.shape[-1]):
+        mask2D = mask[:, :, sliceNum]
+        itemindex = np.where(mask2D == 1)
+        if itemindex[0].size > 0:
+            row = itemindex[0][0]
+            erasePreAtriumMask[:, :, sliceNum][:row, :] = 1
+    return erasePreAtriumMask
+
+
+"""
+Function from
+https://stackoverflow.com/questions/46310603/how-to-compute-convex-hull-image-volume-in-3d-numpy-arrays/46314485#46314485
+"""
+
+
+def fill_hull(image):
+    points = np.transpose(np.where(image))
+    hull = scipy.spatial.ConvexHull(points)
+    deln = scipy.spatial.Delaunay(points[hull.vertices])
+    idx = np.stack(np.indices(image.shape), axis=-1)
+    out_idx = np.nonzero(deln.find_simplex(idx) + 1)
+    out_img = np.zeros(image.shape)
+    out_img[out_idx] = 1
+    return out_img
+
+
+def getClassBinaryMask(TSOutArray, classNum):
+    binaryMask = np.zeros(TSOutArray.shape)
+    binaryMask[TSOutArray == classNum] = 1
+    return binaryMask
+
+
+def loadNiftis(TSNiftiPath, imageNiftiPath):
+    TSArray = loadNiiToArray(TSNiftiPath)
+    scanArray = loadNiiToArray(imageNiftiPath)
+    return TSArray, scanArray
+
+
+# function to select one slice from a 3D volume of a SimpleITK image
+def selectSlice(scanImage, zslice):
+    size = list(scanImage.GetSize())
+    size[2] = 0
+    index = [0, 0, zslice]
+
+    Extractor = sitk.ExtractImageFilter()
+    Extractor.SetSize(size)
+    Extractor.SetIndex(index)
+
+    sliceImage = Extractor.Execute(scanImage)
+    return sliceImage
+
+
+# function to apply windowing
+def windowing(sliceImage, center=400, width=400):
+    windowMinimum = center - (width / 2)
+    windowMaximum = center + (width / 2)
+    img_255 = sitk.Cast(
+        sitk.IntensityWindowing(
+            sliceImage,
+            windowMinimum=windowMinimum,
+            windowMaximum=windowMaximum,
+            outputMinimum=0.0,
+            outputMaximum=255.0,
+        ),
+        sitk.sitkUInt8,
+    )
+    return img_255
+
+
+def selectSampleSlice(kidneyLMask, adRMask, scanImage):
+    # Middle slice of the kidney mask, halfway between its first and last
+    # nonzero slices
+    middleSlice = np.where(kidneyLMask.sum(axis=(0, 1)) > 0)[0][0] + int(
+        (
+            np.where(kidneyLMask.sum(axis=(0, 1)) > 0)[0][-1]
+            - np.where(kidneyLMask.sum(axis=(0, 1)) > 0)[0][0]
+        )
+        / 2
+    )
+    # print("Middle slice: ", middleSlice)
+    # make middleSlice int
+    middleSlice = int(middleSlice)
+    # select one slice using SimpleITK
+    sliceImageK = selectSlice(scanImage, middleSlice)
+
+    # Middle slice of the adrenal mask, halfway between its first and last
+    # nonzero slices
+    middleSlice = np.where(adRMask.sum(axis=(0, 1)) > 0)[0][0] + int(
+        (
+            np.where(adRMask.sum(axis=(0, 1)) > 0)[0][-1]
+            - np.where(adRMask.sum(axis=(0, 1)) > 0)[0][0]
+        )
+        / 2
+    )
+    # print("Middle slice: ", middleSlice)
+    # make middleSlice int
+    middleSlice = int(middleSlice)
+    # select one slice using SimpleITK
+    sliceImageA = selectSlice(scanImage, middleSlice)
+
+    sliceImageK = windowing(sliceImageK)
+    sliceImageA = windowing(sliceImageA)
+
+    return sliceImageK, sliceImageA
+
+
+def getFeatures(TSArray, scanArray):
+    aortaMask = getClassBinaryMask(TSArray, 7)
+    IVCMask = getClassBinaryMask(TSArray, 8)
+    portalMask = getClassBinaryMask(TSArray, 9)
+    atriumMask = getClassBinaryMask(TSArray, 45)
+    kidneyLMask = getClassBinaryMask(TSArray, 3)
+    kidneyRMask = getClassBinaryMask(TSArray, 2)
+    adRMask = getClassBinaryMask(TSArray, 11)
+
+    # Remove thoracic aorta and IVC from the aorta and IVC masks
+    anteriorAtriumMask = getMaskAnteriorAtrium(atriumMask)
+    aortaMask = aortaMask * (anteriorAtriumMask == 0)
+    IVCMask = IVCMask * (anteriorAtriumMask == 0)
+
+    # Erode vessels to keep only the center of the vessels
+    struct2 = np.ones((3, 3, 3))
+    aortaMaskEroded = ndi.binary_erosion(aortaMask, structure=struct2).astype(
+        aortaMask.dtype
+    )
+    IVCMaskEroded = ndi.binary_erosion(IVCMask, structure=struct2).astype(IVCMask.dtype)
+
+    struct3 = np.ones((1, 1, 1))
+    portalMaskEroded = ndi.binary_erosion(portalMask, structure=struct3).astype(
+        portalMask.dtype
+    )
+    # If portalMaskEroded has fewer than 500 voxels, use the original portalMask
+    if np.count_nonzero(portalMaskEroded) < 500:
+        portalMaskEroded = portalMask
+
+    # Get masked values from scan
+    aortaArray = keep_masked_values(scanArray, aortaMaskEroded)
+    IVCArray = keep_masked_values(scanArray, IVCMaskEroded)
+    portalArray = keep_masked_values(scanArray, portalMaskEroded)
+    kidneyLArray = keep_masked_values(scanArray, kidneyLMask)
+    kidneyRArray = keep_masked_values(scanArray, kidneyRMask)
+
+    """TODO: Put this in a separate function and return only the pelvis arrays"""
+    # process the renal pelvis masks from the kidney masks
+    # create the convex hull of the left kidney
+    kidneyLHull = fill_hull(kidneyLMask)
+    # exclude the left kidney mask from the left convex hull
+    kidneyLHull = kidneyLHull * (kidneyLMask == 0)
+    # erode the kidney hull to remove the edges
+    struct = np.ones((3, 3, 3))
+    kidneyLHull = ndi.binary_erosion(kidneyLHull, structure=struct).astype(
+        kidneyLHull.dtype
+    )
+    # keep the values of the scanArray that are in the left convex hull
+    pelvisLArray = keep_masked_values(scanArray, kidneyLHull)
+
+    # create the convex hull of the right kidney
+    kidneyRHull = fill_hull(kidneyRMask)
+    # exclude the right kidney mask from the right convex hull
+    kidneyRHull = kidneyRHull * (kidneyRMask == 0)
+    # erode the kidney hull to remove the edges
+    struct = np.ones((3, 3, 3))
+    kidneyRHull = ndi.binary_erosion(kidneyRHull, structure=struct).astype(
+        kidneyRHull.dtype
+    )
+    # keep the values of the scanArray that are in the right convex hull
+    pelvisRArray = keep_masked_values(scanArray, kidneyRHull)
+
+    # Get the stats
+    # Get the stats for the aortaArray
+    (
+        aorta_max_val,
+        aorta_min_val,
+        aorta_mean_val,
+        aorta_median_val,
+        aorta_std_val,
+        aorta_variance_val,
+    ) = get_stats(aortaArray)
+
+    # Get the stats for the IVCArray
+    (
+        IVC_max_val,
+        IVC_min_val,
+        IVC_mean_val,
+        IVC_median_val,
+
IVC_std_val, + IVC_variance_val, + ) = get_stats(IVCArray) + + # Get the stats for the portalArray + ( + portal_max_val, + portal_min_val, + portal_mean_val, + portal_median_val, + portal_std_val, + portal_variance_val, + ) = get_stats(portalArray) + + # Get the stats for the kidneyLArray and kidneyRArray + ( + kidneyL_max_val, + kidneyL_min_val, + kidneyL_mean_val, + kidneyL_median_val, + kidneyL_std_val, + kidneyL_variance_val, + ) = get_stats(kidneyLArray) + ( + kidneyR_max_val, + kidneyR_min_val, + kidneyR_mean_val, + kidneyR_median_val, + kidneyR_std_val, + kidneyR_variance_val, + ) = get_stats(kidneyRArray) + + ( + pelvisL_max_val, + pelvisL_min_val, + pelvisL_mean_val, + pelvisL_median_val, + pelvisL_std_val, + pelvisL_variance_val, + ) = get_stats(pelvisLArray) + ( + pelvisR_max_val, + pelvisR_min_val, + pelvisR_mean_val, + pelvisR_median_val, + pelvisR_std_val, + pelvisR_variance_val, + ) = get_stats(pelvisRArray) + + # create three new columns for the decision tree + # aorta - porta, Max min and mean columns + aorta_porta_max = aorta_max_val - portal_max_val + aorta_porta_min = aorta_min_val - portal_min_val + aorta_porta_mean = aorta_mean_val - portal_mean_val + + # aorta - IVC, Max min and mean columns + aorta_IVC_max = aorta_max_val - IVC_max_val + aorta_IVC_min = aorta_min_val - IVC_min_val + aorta_IVC_mean = aorta_mean_val - IVC_mean_val + + # Save stats in CSV: + # Create a list to store the stats + stats = [] + # Add the stats for the aortaArray to the list + stats.extend( + [ + aorta_max_val, + aorta_min_val, + aorta_mean_val, + aorta_median_val, + aorta_std_val, + aorta_variance_val, + ] + ) + # Add the stats for the IVCArray to the list + stats.extend( + [ + IVC_max_val, + IVC_min_val, + IVC_mean_val, + IVC_median_val, + IVC_std_val, + IVC_variance_val, + ] + ) + # Add the stats for the portalArray to the list + stats.extend( + [ + portal_max_val, + portal_min_val, + portal_mean_val, + portal_median_val, + portal_std_val, + portal_variance_val, + ] + ) + # Add the stats for the kidneyLArray and kidneyRArray to the list + stats.extend( + [ + kidneyL_max_val, + kidneyL_min_val, + kidneyL_mean_val, + kidneyL_median_val, + kidneyL_std_val, + kidneyL_variance_val, + ] + ) + stats.extend( + [ + kidneyR_max_val, + kidneyR_min_val, + kidneyR_mean_val, + kidneyR_median_val, + kidneyR_std_val, + kidneyR_variance_val, + ] + ) + # Add the stats for the kidneyLHull and kidneyRHull to the list + stats.extend( + [ + pelvisL_max_val, + pelvisL_min_val, + pelvisL_mean_val, + pelvisL_median_val, + pelvisL_std_val, + pelvisL_variance_val, + ] + ) + stats.extend( + [ + pelvisR_max_val, + pelvisR_min_val, + pelvisR_mean_val, + pelvisR_median_val, + pelvisR_std_val, + pelvisR_variance_val, + ] + ) + + stats.extend( + [ + aorta_porta_max, + aorta_porta_min, + aorta_porta_mean, + aorta_IVC_max, + aorta_IVC_min, + aorta_IVC_mean, + ] + ) + + return stats, kidneyLMask, adRMask + + +def loadModel(): + c2cPath = os.path.dirname(sys.path[0]) + filename = os.path.join(c2cPath, "comp2comp", "contrast_phase", "xgboost.pkl") + model = pickle.load(open(filename, "rb")) + + return model + + +def predict_phase(TS_path, scan_path, outputPath=None, save_sample=False): + TS_array, image_array = loadNiftis(TS_path, scan_path) + model = loadModel() + # TS_array, image_array = loadNiftis(TS_output_nifti_path, image_nifti_path) + featureArray, kidneyLMask, adRMask = getFeatures(TS_array, image_array) + y_pred = model.predict([featureArray]) + + if y_pred == 0: + pred_phase = "non-contrast" + if y_pred == 1: + 
pred_phase = "arterial" + if y_pred == 2: + pred_phase = "venous" + if y_pred == 3: + pred_phase = "delayed" + + output_path_metrics = os.path.join(outputPath, "metrics") + if not os.path.exists(output_path_metrics): + os.makedirs(output_path_metrics) + outputTxt = os.path.join(output_path_metrics, "phase_prediction.txt") + with open(outputTxt, "w") as text_file: + text_file.write(pred_phase) + print(pred_phase) + + output_path_images = os.path.join(outputPath, "images") + if not os.path.exists(output_path_images): + os.makedirs(output_path_images) + scanImage = loadNiiImageWithSitk(scan_path) + sliceImageK, sliceImageA = selectSampleSlice(kidneyLMask, adRMask, scanImage) + outJpgK = os.path.join(output_path_images, "sampleSliceKidney.png") + sitk.WriteImage(sliceImageK, outJpgK) + outJpgA = os.path.join(output_path_images, "sampleSliceAdrenal.png") + sitk.WriteImage(sliceImageA, outJpgA) + + +if __name__ == "__main__": + # parse arguments optional + parser = argparse.ArgumentParser() + parser.add_argument("--TS_path", type=str, required=True, help="Input image") + parser.add_argument("--scan_path", type=str, required=True, help="Input image") + parser.add_argument( + "--output_dir", + type=str, + required=False, + help="Output .txt prediction", + default=None, + ) + parser.add_argument( + "--save_sample", + type=bool, + required=False, + help="Save jpeg sample ", + default=False, + ) + args = parser.parse_args() + predict_phase(args.TS_path, args.scan_path, args.output_dir, args.save_sample) diff --git a/Comp2Comp-main/comp2comp/contrast_phase/contrast_phase.py b/Comp2Comp-main/comp2comp/contrast_phase/contrast_phase.py new file mode 100644 index 0000000000000000000000000000000000000000..bc144a2e5331971d779c0adba2bd1f5f42357693 --- /dev/null +++ b/Comp2Comp-main/comp2comp/contrast_phase/contrast_phase.py @@ -0,0 +1,116 @@ +import os +from pathlib import Path +from time import time +from typing import Union + +from totalsegmentator.libs import ( + download_pretrained_weights, + nostdout, + setup_nnunet, +) + +from comp2comp.contrast_phase.contrast_inf import predict_phase +from comp2comp.inference_class_base import InferenceClass + + +class ContrastPhaseDetection(InferenceClass): + """Contrast Phase Detection.""" + + def __init__(self, input_path): + super().__init__() + self.input_path = input_path + + def __call__(self, inference_pipeline): + self.output_dir = inference_pipeline.output_dir + self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/") + if not os.path.exists(self.output_dir_segmentations): + os.makedirs(self.output_dir_segmentations) + self.model_dir = inference_pipeline.model_dir + + seg, img = self.run_segmentation( + os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), + self.output_dir_segmentations + "s01.nii.gz", + inference_pipeline.model_dir, + ) + + # segArray, imgArray = self.convertNibToNumpy(seg, img) + + imgNiftiPath = os.path.join( + self.output_dir_segmentations, "converted_dcm.nii.gz" + ) + segNiftPath = os.path.join(self.output_dir_segmentations, "s01.nii.gz") + + predict_phase(segNiftPath, imgNiftiPath, outputPath=self.output_dir) + + return {} + + def run_segmentation( + self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir + ): + """Run segmentation. + + Args: + input_path (Union[str, Path]): Input path. + output_path (Union[str, Path]): Output path. 
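+            model_dir: Model directory.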
+ """ + + print("Segmenting...") + st = time() + os.environ["SCRATCH"] = self.model_dir + + # Setup nnunet + model = "3d_fullres" + folds = [0] + trainer = "nnUNetTrainerV2_ep4000_nomirror" + crop_path = None + task_id = [251] + + setup_nnunet() + for task_id in [251]: + download_pretrained_weights(task_id) + + from totalsegmentator.nnunet import nnUNet_predict_image + + with nostdout(): + img, seg = nnUNet_predict_image( + input_path, + output_path, + task_id, + model=model, + folds=folds, + trainer=trainer, + tta=False, + multilabel_image=True, + resample=1.5, + crop=None, + crop_path=crop_path, + task_name="total", + nora_tag=None, + preview=False, + nr_threads_resampling=1, + nr_threads_saving=6, + quiet=False, + verbose=False, + test=0, + ) + end = time() + + # Log total time for spine segmentation + print(f"Total time for segmentation: {end-st:.2f}s.") + + return seg, img + + def convertNibToNumpy(self, TSNib, ImageNib): + """Convert nifti to numpy array. + + Args: + TSNib (nibabel.nifti1.Nifti1Image): TotalSegmentator output. + ImageNib (nibabel.nifti1.Nifti1Image): Input image. + + Returns: + numpy.ndarray: TotalSegmentator output. + numpy.ndarray: Input image. + """ + TS_array = TSNib.get_fdata() + img_array = ImageNib.get_fdata() + return TS_array, img_array diff --git a/Comp2Comp-main/comp2comp/contrast_phase/xgboost.pkl b/Comp2Comp-main/comp2comp/contrast_phase/xgboost.pkl new file mode 100644 index 0000000000000000000000000000000000000000..8edccb5f57aef8fa016473a3ad94ce0462be3d5b --- /dev/null +++ b/Comp2Comp-main/comp2comp/contrast_phase/xgboost.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:070af05754cc9541e924c0ede654b1c40a01b9240f14483af5284ae0b92d4169 +size 422989 diff --git a/Comp2Comp-main/comp2comp/hip/hip.py b/Comp2Comp-main/comp2comp/hip/hip.py new file mode 100644 index 0000000000000000000000000000000000000000..692bda4951694c6ec648a11d5ae1517a0a20a461 --- /dev/null +++ b/Comp2Comp-main/comp2comp/hip/hip.py @@ -0,0 +1,301 @@ +""" +@author: louisblankemeier +""" + +import os +from pathlib import Path +from time import time +from typing import Union + +import pandas as pd +from totalsegmentator.libs import ( + download_pretrained_weights, + nostdout, + setup_nnunet, +) + +from comp2comp.hip import hip_utils +from comp2comp.hip.hip_visualization import ( + hip_report_visualizer, + hip_roi_visualizer, +) +from comp2comp.inference_class_base import InferenceClass +from comp2comp.models.models import Models + + +class HipSegmentation(InferenceClass): + """Spine segmentation.""" + + def __init__(self, model_name): + super().__init__() + self.model_name = model_name + self.model = Models.model_from_name(model_name) + + def __call__(self, inference_pipeline): + # inference_pipeline.dicom_series_path = self.input_path + self.output_dir = inference_pipeline.output_dir + self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/") + if not os.path.exists(self.output_dir_segmentations): + os.makedirs(self.output_dir_segmentations) + + self.model_dir = inference_pipeline.model_dir + + seg, mv = self.hip_seg( + os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), + self.output_dir_segmentations + "hip.nii.gz", + inference_pipeline.model_dir, + ) + + inference_pipeline.model = self.model + inference_pipeline.segmentation = seg + inference_pipeline.medical_volume = mv + + return {} + + def hip_seg( + self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir + ): + """Run spine segmentation. 
+ + Args: + input_path (Union[str, Path]): Input path. + output_path (Union[str, Path]): Output path. + """ + + print("Segmenting hip...") + st = time() + os.environ["SCRATCH"] = self.model_dir + + # Setup nnunet + model = "3d_fullres" + folds = [0] + trainer = "nnUNetTrainerV2_ep4000_nomirror" + crop_path = None + task_id = [254] + + if self.model_name == "ts_hip": + setup_nnunet() + download_pretrained_weights(task_id[0]) + else: + raise ValueError("Invalid model name.") + + from totalsegmentator.nnunet import nnUNet_predict_image + + with nostdout(): + img, seg = nnUNet_predict_image( + input_path, + output_path, + task_id, + model=model, + folds=folds, + trainer=trainer, + tta=False, + multilabel_image=True, + resample=1.5, + crop=None, + crop_path=crop_path, + task_name="total", + nora_tag=None, + preview=False, + nr_threads_resampling=1, + nr_threads_saving=6, + quiet=False, + verbose=False, + test=0, + ) + end = time() + + # Log total time for hip segmentation + print(f"Total time for hip segmentation: {end-st:.2f}s.") + + return seg, img + + +class HipComputeROIs(InferenceClass): + def __init__(self, hip_model): + super().__init__() + self.hip_model_name = hip_model + self.hip_model_type = Models.model_from_name(self.hip_model_name) + + def __call__(self, inference_pipeline): + segmentation = inference_pipeline.segmentation + medical_volume = inference_pipeline.medical_volume + + model = inference_pipeline.model + images_folder = os.path.join(inference_pipeline.output_dir, "dev") + results_dict = hip_utils.compute_rois( + medical_volume, segmentation, model, images_folder + ) + inference_pipeline.femur_results_dict = results_dict + return {} + + +class HipMetricsSaver(InferenceClass): + """Save metrics to a CSV file.""" + + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline): + metrics_output_dir = os.path.join(inference_pipeline.output_dir, "metrics") + if not os.path.exists(metrics_output_dir): + os.makedirs(metrics_output_dir) + results_dict = inference_pipeline.femur_results_dict + left_head_hu = results_dict["left_head"]["hu"] + right_head_hu = results_dict["right_head"]["hu"] + left_intertrochanter_hu = results_dict["left_intertrochanter"]["hu"] + right_intertrochanter_hu = results_dict["right_intertrochanter"]["hu"] + left_neck_hu = results_dict["left_neck"]["hu"] + right_neck_hu = results_dict["right_neck"]["hu"] + # save to csv + df = pd.DataFrame( + { + "Left Head (HU)": [left_head_hu], + "Right Head (HU)": [right_head_hu], + "Left Intertrochanter (HU)": [left_intertrochanter_hu], + "Right Intertrochanter (HU)": [right_intertrochanter_hu], + "Left Neck (HU)": [left_neck_hu], + "Right Neck (HU)": [right_neck_hu], + } + ) + df.to_csv(os.path.join(metrics_output_dir, "hip_metrics.csv"), index=False) + return {} + + +class HipVisualizer(InferenceClass): + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline): + medical_volume = inference_pipeline.medical_volume + + left_head_roi = inference_pipeline.femur_results_dict["left_head"]["roi"] + left_head_centroid = inference_pipeline.femur_results_dict["left_head"][ + "centroid" + ] + left_head_hu = inference_pipeline.femur_results_dict["left_head"]["hu"] + + left_intertrochanter_roi = inference_pipeline.femur_results_dict[ + "left_intertrochanter" + ]["roi"] + left_intertrochanter_centroid = inference_pipeline.femur_results_dict[ + "left_intertrochanter" + ]["centroid"] + left_intertrochanter_hu = inference_pipeline.femur_results_dict[ + "left_intertrochanter" + 
]["hu"] + + left_neck_roi = inference_pipeline.femur_results_dict["left_neck"]["roi"] + left_neck_centroid = inference_pipeline.femur_results_dict["left_neck"][ + "centroid" + ] + left_neck_hu = inference_pipeline.femur_results_dict["left_neck"]["hu"] + + right_head_roi = inference_pipeline.femur_results_dict["right_head"]["roi"] + right_head_centroid = inference_pipeline.femur_results_dict["right_head"][ + "centroid" + ] + right_head_hu = inference_pipeline.femur_results_dict["right_head"]["hu"] + + right_intertrochanter_roi = inference_pipeline.femur_results_dict[ + "right_intertrochanter" + ]["roi"] + right_intertrochanter_centroid = inference_pipeline.femur_results_dict[ + "right_intertrochanter" + ]["centroid"] + right_intertrochanter_hu = inference_pipeline.femur_results_dict[ + "right_intertrochanter" + ]["hu"] + + right_neck_roi = inference_pipeline.femur_results_dict["right_neck"]["roi"] + right_neck_centroid = inference_pipeline.femur_results_dict["right_neck"][ + "centroid" + ] + right_neck_hu = inference_pipeline.femur_results_dict["right_neck"]["hu"] + + output_dir = inference_pipeline.output_dir + images_output_dir = os.path.join(output_dir, "images") + if not os.path.exists(images_output_dir): + os.makedirs(images_output_dir) + hip_roi_visualizer( + medical_volume, + left_head_roi, + left_head_centroid, + left_head_hu, + images_output_dir, + "left_head", + ) + hip_roi_visualizer( + medical_volume, + left_intertrochanter_roi, + left_intertrochanter_centroid, + left_intertrochanter_hu, + images_output_dir, + "left_intertrochanter", + ) + hip_roi_visualizer( + medical_volume, + left_neck_roi, + left_neck_centroid, + left_neck_hu, + images_output_dir, + "left_neck", + ) + hip_roi_visualizer( + medical_volume, + right_head_roi, + right_head_centroid, + right_head_hu, + images_output_dir, + "right_head", + ) + hip_roi_visualizer( + medical_volume, + right_intertrochanter_roi, + right_intertrochanter_centroid, + right_intertrochanter_hu, + images_output_dir, + "right_intertrochanter", + ) + hip_roi_visualizer( + medical_volume, + right_neck_roi, + right_neck_centroid, + right_neck_hu, + images_output_dir, + "right_neck", + ) + hip_report_visualizer( + medical_volume.get_fdata(), + left_head_roi + right_head_roi, + [left_head_centroid, right_head_centroid], + images_output_dir, + "head", + { + "Left Head HU": round(left_head_hu), + "Right Head HU": round(right_head_hu), + }, + ) + hip_report_visualizer( + medical_volume.get_fdata(), + left_intertrochanter_roi + right_intertrochanter_roi, + [left_intertrochanter_centroid, right_intertrochanter_centroid], + images_output_dir, + "intertrochanter", + { + "Left Intertrochanter HU": round(left_intertrochanter_hu), + "Right Intertrochanter HU": round(right_intertrochanter_hu), + }, + ) + hip_report_visualizer( + medical_volume.get_fdata(), + left_neck_roi + right_neck_roi, + [left_neck_centroid, right_neck_centroid], + images_output_dir, + "neck", + { + "Left Neck HU": round(left_neck_hu), + "Right Neck HU": round(right_neck_hu), + }, + ) + return {} diff --git a/Comp2Comp-main/comp2comp/hip/hip_utils.py b/Comp2Comp-main/comp2comp/hip/hip_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..60a776d62cafa27743672dc3d964b6bfb542384d --- /dev/null +++ b/Comp2Comp-main/comp2comp/hip/hip_utils.py @@ -0,0 +1,362 @@ +""" +@author: louisblankemeier +""" + +import math +import os +import shutil + +import cv2 +import nibabel as nib +import numpy as np +import scipy.ndimage as ndi +from scipy.ndimage import zoom +from 
skimage.morphology import ball, binary_erosion + +from comp2comp.hip.hip_visualization import method_visualizer + + +def compute_rois(medical_volume, segmentation, model, output_dir, save=False): + left_femur_mask = segmentation.get_fdata() == model.categories["femur_left"] + left_femur_mask = left_femur_mask.astype(np.uint8) + right_femur_mask = segmentation.get_fdata() == model.categories["femur_right"] + right_femur_mask = right_femur_mask.astype(np.uint8) + left_head_roi, left_head_centroid, left_head_hu = get_femural_head_roi( + left_femur_mask, medical_volume, output_dir, "left_head" + ) + right_head_roi, right_head_centroid, right_head_hu = get_femural_head_roi( + right_femur_mask, medical_volume, output_dir, "right_head" + ) + ( + left_intertrochanter_roi, + left_intertrochanter_centroid, + left_intertrochanter_hu, + ) = get_femural_head_roi( + left_femur_mask, medical_volume, output_dir, "left_intertrochanter" + ) + ( + right_intertrochanter_roi, + right_intertrochanter_centroid, + right_intertrochanter_hu, + ) = get_femural_head_roi( + right_femur_mask, medical_volume, output_dir, "right_intertrochanter" + ) + ( + left_neck_roi, + left_neck_centroid, + left_neck_hu, + ) = get_femural_neck_roi( + left_femur_mask, + medical_volume, + left_intertrochanter_roi, + left_intertrochanter_centroid, + left_head_roi, + left_head_centroid, + output_dir, + ) + ( + right_neck_roi, + right_neck_centroid, + right_neck_hu, + ) = get_femural_neck_roi( + right_femur_mask, + medical_volume, + right_intertrochanter_roi, + right_intertrochanter_centroid, + right_head_roi, + right_head_centroid, + output_dir, + ) + combined_roi = ( + left_head_roi + + (right_head_roi) # * 2) + + (left_intertrochanter_roi) # * 3) + + (right_intertrochanter_roi) # * 4) + + (left_neck_roi) # * 5) + + (right_neck_roi) # * 6) + ) + + if save: + # make roi directory if it doesn't exist + parent_output_dir = os.path.dirname(output_dir) + roi_output_dir = os.path.join(parent_output_dir, "rois") + if not os.path.exists(roi_output_dir): + os.makedirs(roi_output_dir) + + # Convert left ROI to NIfTI + left_roi_nifti = nib.Nifti1Image(combined_roi, medical_volume.affine) + left_roi_path = os.path.join(roi_output_dir, "roi.nii.gz") + nib.save(left_roi_nifti, left_roi_path) + shutil.copy( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "tunnelvision.ipynb", + ), + parent_output_dir, + ) + + return { + "left_head": { + "roi": left_head_roi, + "centroid": left_head_centroid, + "hu": left_head_hu, + }, + "right_head": { + "roi": right_head_roi, + "centroid": right_head_centroid, + "hu": right_head_hu, + }, + "left_intertrochanter": { + "roi": left_intertrochanter_roi, + "centroid": left_intertrochanter_centroid, + "hu": left_intertrochanter_hu, + }, + "right_intertrochanter": { + "roi": right_intertrochanter_roi, + "centroid": right_intertrochanter_centroid, + "hu": right_intertrochanter_hu, + }, + "left_neck": { + "roi": left_neck_roi, + "centroid": left_neck_centroid, + "hu": left_neck_hu, + }, + "right_neck": { + "roi": right_neck_roi, + "centroid": right_neck_centroid, + "hu": right_neck_hu, + }, + } + + +def get_femural_head_roi( + femur_mask, + medical_volume, + output_dir, + anatomy, + visualize_method=False, + min_pixel_count=20, +): + top = np.where(femur_mask.sum(axis=(0, 1)) != 0)[0].max() + top_mask = femur_mask[:, :, top] + + print(f"======== Computing {anatomy} femur ROIs ========") + + while True: + labeled, num_features = ndi.label(top_mask) + + component_sizes = np.bincount(labeled.ravel()) + 
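+        # component_sizes[i] is the pixel count of label i on this slice;
+        # label 0 (the background) always qualifies, so it is dropped from
+        # the candidates via the [1:] below.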
valid_components = np.where(component_sizes >= min_pixel_count)[0][1:] + + if len(valid_components) == 2: + break + + top -= 1 + if top < 0: + print("Two connected components not found in the femur mask.") + break + top_mask = femur_mask[:, :, top] + + if len(valid_components) == 2: + # Find the center of mass for each connected component + center_of_mass_1 = list( + ndi.center_of_mass(top_mask, labeled, valid_components[0]) + ) + center_of_mass_2 = list( + ndi.center_of_mass(top_mask, labeled, valid_components[1]) + ) + + # Assign left_center_of_mass to be the center of mass with lowest value in the first dimension + if center_of_mass_1[0] < center_of_mass_2[0]: + left_center_of_mass = center_of_mass_1 + right_center_of_mass = center_of_mass_2 + else: + left_center_of_mass = center_of_mass_2 + right_center_of_mass = center_of_mass_1 + + print(f"Left center of mass: {left_center_of_mass}") + print(f"Right center of mass: {right_center_of_mass}") + + if anatomy == "left_intertrochanter" or anatomy == "right_head": + center_of_mass = left_center_of_mass + elif anatomy == "right_intertrochanter" or anatomy == "left_head": + center_of_mass = right_center_of_mass + + coronal_slice = femur_mask[:, round(center_of_mass[1]), :] + coronal_image = medical_volume.get_fdata()[:, round(center_of_mass[1]), :] + sagittal_slice = femur_mask[round(center_of_mass[0]), :, :] + sagittal_image = medical_volume.get_fdata()[round(center_of_mass[0]), :, :] + + zooms = medical_volume.header.get_zooms() + zoom_factor = zooms[2] / zooms[1] + + coronal_slice = zoom(coronal_slice, (1, zoom_factor), order=1).round() + coronal_image = zoom(coronal_image, (1, zoom_factor), order=3).round() + sagittal_image = zoom(sagittal_image, (1, zoom_factor), order=3).round() + + centroid = [round(center_of_mass[0]), 0, 0] + + print(f"Starting centroid: {centroid}") + + for _ in range(3): + sagittal_slice = femur_mask[centroid[0], :, :] + sagittal_slice = zoom(sagittal_slice, (1, zoom_factor), order=1).round() + centroid[1], centroid[2], radius_sagittal = inscribe_sagittal( + sagittal_slice, zoom_factor + ) + + print(f"Centroid after inscribe sagittal: {centroid}") + + axial_slice = femur_mask[:, :, centroid[2]] + if anatomy == "left_intertrochanter" or anatomy == "right_head": + axial_slice[round(right_center_of_mass[0]) :, :] = 0 + elif anatomy == "right_intertrochanter" or anatomy == "left_head": + axial_slice[: round(left_center_of_mass[0]), :] = 0 + centroid[0], centroid[1], radius_axial = inscribe_axial(axial_slice) + + print(f"Centroid after inscribe axial: {centroid}") + + axial_image = medical_volume.get_fdata()[:, :, round(centroid[2])] + sagittal_image = medical_volume.get_fdata()[round(centroid[0]), :, :] + sagittal_image = zoom(sagittal_image, (1, zoom_factor), order=3).round() + + if visualize_method: + method_visualizer( + sagittal_image, + axial_image, + axial_slice, + sagittal_slice, + [centroid[2], centroid[1]], + radius_sagittal, + [centroid[1], centroid[0]], + radius_axial, + output_dir, + anatomy, + ) + + roi = compute_hip_roi(medical_volume, centroid, radius_sagittal, radius_axial) + + # selem = ndi.generate_binary_structure(3, 1) + selem = ball(3) + femur_mask_eroded = binary_erosion(femur_mask, selem) + roi = roi * femur_mask_eroded + roi_eroded = roi.astype(np.uint8) + + hu = get_mean_roi_hu(medical_volume, roi_eroded) + + return (roi_eroded, centroid, hu) + + +def get_femural_neck_roi( + femur_mask, + medical_volume, + intertrochanter_roi, + intertrochanter_centroid, + head_roi, + head_centroid, + 
output_dir, +): + zooms = medical_volume.header.get_zooms() + + direction_vector = np.array(head_centroid) - np.array(intertrochanter_centroid) + unit_direction_vector = direction_vector / np.linalg.norm(direction_vector) + + z, y, x = np.where(intertrochanter_roi) + intertrochanter_points = np.column_stack((z, y, x)) + t_start = np.dot( + intertrochanter_points - intertrochanter_centroid, unit_direction_vector + ).max() + + z, y, x = np.where(head_roi) + head_points = np.column_stack((z, y, x)) + t_end = ( + np.linalg.norm(direction_vector) + + np.dot(head_points - head_centroid, unit_direction_vector).min() + ) + + z, y, x = np.indices(femur_mask.shape) + coordinates = np.stack((z, y, x), axis=-1) + + distance_to_line_origin = np.dot( + coordinates - intertrochanter_centroid, unit_direction_vector + ) + + coordinates_zoomed = coordinates * zooms + intertrochanter_centroid_zoomed = np.array(intertrochanter_centroid) * zooms + unit_direction_vector_zoomed = unit_direction_vector * zooms + + distance_to_line = np.linalg.norm( + np.cross( + coordinates_zoomed - intertrochanter_centroid_zoomed, + coordinates_zoomed + - (intertrochanter_centroid_zoomed + unit_direction_vector_zoomed), + ), + axis=-1, + ) / np.linalg.norm(unit_direction_vector_zoomed) + + cylinder_radius = 10 + + cylinder_mask = ( + (distance_to_line <= cylinder_radius) + & (distance_to_line_origin >= t_start) + & (distance_to_line_origin <= t_end) + ) + + # selem = ndi.generate_binary_structure(3, 1) + selem = ball(3) + femur_mask_eroded = binary_erosion(femur_mask, selem) + roi = cylinder_mask * femur_mask_eroded + neck_roi = roi.astype(np.uint8) + + hu = get_mean_roi_hu(medical_volume, neck_roi) + + centroid = list( + intertrochanter_centroid + unit_direction_vector * (t_start + t_end) / 2 + ) + centroid = [round(x) for x in centroid] + + return neck_roi, centroid, hu + + +def compute_hip_roi(img, centroid, radius_sagittal, radius_axial): + pixel_spacing = img.header.get_zooms() + length_i = radius_axial * 0.75 / pixel_spacing[0] + length_j = radius_axial * 0.75 / pixel_spacing[1] + length_k = radius_sagittal * 0.75 / pixel_spacing[2] + + roi = np.zeros(img.get_fdata().shape, dtype=np.uint8) + i_lower = math.floor(centroid[0] - length_i) + j_lower = math.floor(centroid[1] - length_j) + k_lower = math.floor(centroid[2] - length_k) + for i in range(i_lower, i_lower + 2 * math.ceil(length_i) + 1): + for j in range(j_lower, j_lower + 2 * math.ceil(length_j) + 1): + for k in range(k_lower, k_lower + 2 * math.ceil(length_k) + 1): + if (i - centroid[0]) ** 2 / length_i**2 + ( + j - centroid[1] + ) ** 2 / length_j**2 + (k - centroid[2]) ** 2 / length_k**2 <= 1: + roi[i, j, k] = 1 + return roi + + +def inscribe_axial(axial_mask): + dist_map = cv2.distanceTransform(axial_mask, cv2.DIST_L2, cv2.DIST_MASK_PRECISE) + _, radius_axial, _, center_axial = cv2.minMaxLoc(dist_map) + center_axial = list(center_axial) + left_right_center = round(center_axial[1]) + posterior_anterior_center = round(center_axial[0]) + return left_right_center, posterior_anterior_center, radius_axial + + +def inscribe_sagittal(sagittal_mask, zoom_factor): + dist_map = cv2.distanceTransform(sagittal_mask, cv2.DIST_L2, cv2.DIST_MASK_PRECISE) + _, radius_sagittal, _, center_sagittal = cv2.minMaxLoc(dist_map) + center_sagittal = list(center_sagittal) + posterior_anterior_center = round(center_sagittal[1]) + inferior_superior_center = round(center_sagittal[0]) + inferior_superior_center = round(inferior_superior_center / zoom_factor) + return posterior_anterior_center, 
inferior_superior_center, radius_sagittal + + +def get_mean_roi_hu(medical_volume, roi): + masked_medical_volume = medical_volume.get_fdata() * roi + return np.mean(masked_medical_volume[masked_medical_volume != 0]) diff --git a/Comp2Comp-main/comp2comp/hip/hip_visualization.py b/Comp2Comp-main/comp2comp/hip/hip_visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..86e8631e1950d3e2a32dfcf9ecdb3208489d6828 --- /dev/null +++ b/Comp2Comp-main/comp2comp/hip/hip_visualization.py @@ -0,0 +1,171 @@ +""" +@author: louisblankemeier +""" + +import os + +import numpy as np +from scipy.ndimage import zoom + +from comp2comp.visualization.detectron_visualizer import Visualizer +from comp2comp.visualization.linear_planar_reformation import ( + linear_planar_reformation, +) + + +def method_visualizer( + sagittal_image, + axial_image, + axial_slice, + sagittal_slice, + center_sagittal, + radius_sagittal, + center_axial, + radius_axial, + output_dir, + anatomy, +): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + axial_image = np.clip(axial_image, -300, 1800) + axial_image = normalize_img(axial_image) * 255.0 + + sagittal_image = np.clip(sagittal_image, -300, 1800) + sagittal_image = normalize_img(sagittal_image) * 255.0 + + sagittal_image = sagittal_image.reshape( + (sagittal_image.shape[0], sagittal_image.shape[1], 1) + ) + img_rgb = np.tile(sagittal_image, (1, 1, 3)) + vis = Visualizer(img_rgb) + vis.draw_circle( + circle_coord=center_sagittal, color=[0, 1, 0], radius=radius_sagittal + ) + vis.draw_binary_mask(sagittal_slice) + + vis_obj = vis.get_output() + vis_obj.save(os.path.join(output_dir, f"{anatomy}_sagittal_method.png")) + + axial_image = axial_image.reshape((axial_image.shape[0], axial_image.shape[1], 1)) + img_rgb = np.tile(axial_image, (1, 1, 3)) + vis = Visualizer(img_rgb) + vis.draw_circle(circle_coord=center_axial, color=[0, 1, 0], radius=radius_axial) + vis.draw_binary_mask(axial_slice) + + vis_obj = vis.get_output() + vis_obj.save(os.path.join(output_dir, f"{anatomy}_axial_method.png")) + + +def hip_roi_visualizer( + medical_volume, + roi, + centroid, + hu, + output_dir, + anatomy, +): + zooms = medical_volume.header.get_zooms() + zoom_factor = zooms[2] / zooms[1] + + sagittal_image = medical_volume.get_fdata()[centroid[0], :, :] + sagittal_roi = roi[centroid[0], :, :] + + sagittal_image = zoom(sagittal_image, (1, zoom_factor), order=1).round() + sagittal_roi = zoom(sagittal_roi, (1, zoom_factor), order=3).round() + sagittal_image = np.flip(sagittal_image.T) + sagittal_roi = np.flip(sagittal_roi.T) + + axial_image = medical_volume.get_fdata()[:, :, round(centroid[2])] + axial_roi = roi[:, :, round(centroid[2])] + + axial_image = np.flip(axial_image.T) + axial_roi = np.flip(axial_roi.T) + + _ROI_COLOR = np.array([1.000, 0.340, 0.200]) + + sagittal_image = np.clip(sagittal_image, -300, 1800) + sagittal_image = normalize_img(sagittal_image) * 255.0 + sagittal_image = sagittal_image.reshape( + (sagittal_image.shape[0], sagittal_image.shape[1], 1) + ) + img_rgb = np.tile(sagittal_image, (1, 1, 3)) + vis = Visualizer(img_rgb) + vis.draw_binary_mask( + sagittal_roi, + color=_ROI_COLOR, + edge_color=_ROI_COLOR, + alpha=0.0, + area_threshold=0, + ) + vis.draw_text( + text=f"Mean HU: {round(hu)}", + position=(412, 10), + color=_ROI_COLOR, + font_size=9, + horizontal_alignment="left", + ) + vis_obj = vis.get_output() + vis_obj.save(os.path.join(output_dir, f"{anatomy}_hip_roi_sagittal.png")) + + """ + axial_image = np.clip(axial_image, -300, 1800) + 
axial_image = normalize_img(axial_image) * 255.0 + axial_image = axial_image.reshape((axial_image.shape[0], axial_image.shape[1], 1)) + img_rgb = np.tile(axial_image, (1, 1, 3)) + vis = Visualizer(img_rgb) + vis.draw_binary_mask( + axial_roi, color=_ROI_COLOR, edge_color=_ROI_COLOR, alpha=0.0, area_threshold=0 + ) + vis.draw_text( + text=f"Mean HU: {round(hu)}", + position=(412, 10), + color=_ROI_COLOR, + font_size=9, + horizontal_alignment="left", + ) + vis_obj = vis.get_output() + vis_obj.save(os.path.join(output_dir, f"{anatomy}_hip_roi_axial.png")) + """ + + +def hip_report_visualizer(medical_volume, roi, centroids, output_dir, anatomy, labels): + _ROI_COLOR = np.array([1.000, 0.340, 0.200]) + image, mask = linear_planar_reformation( + medical_volume, roi, centroids, dimension="axial" + ) + # add 3rd dim to image + image = np.flip(image.T) + mask = np.flip(mask.T) + mask[mask > 1] = 1 + # mask = np.expand_dims(mask, axis=2) + image = np.expand_dims(image, axis=2) + image = np.clip(image, -300, 1800) + image = normalize_img(image) * 255.0 + img_rgb = np.tile(image, (1, 1, 3)) + vis = Visualizer(img_rgb) + vis.draw_binary_mask( + mask, color=_ROI_COLOR, edge_color=_ROI_COLOR, alpha=0.0, area_threshold=0 + ) + pos_idx = 0 + for key, value in labels.items(): + vis.draw_text( + text=f"{key}: {value}", + position=(310, 10 + pos_idx * 17), + color=_ROI_COLOR, + font_size=9, + horizontal_alignment="left", + ) + pos_idx += 1 + vis_obj = vis.get_output() + vis_obj.save(os.path.join(output_dir, f"{anatomy}_report_axial.png")) + + +def normalize_img(img: np.ndarray) -> np.ndarray: + """Normalize the image. + Args: + img (np.ndarray): Input image. + Returns: + np.ndarray: Normalized image. + """ + return (img - img.min()) / (img.max() - img.min()) diff --git a/Comp2Comp-main/comp2comp/hip/tunnelvision.ipynb b/Comp2Comp-main/comp2comp/hip/tunnelvision.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..4db9f56a9e1a130cb72a63cf142bb02b4dceea0d --- /dev/null +++ b/Comp2Comp-main/comp2comp/hip/tunnelvision.ipynb @@ -0,0 +1,73 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import voxel as vx\n", + "import tunnelvision as tv\n", + "import numpy as np\n", + "\n", + "mv = vx.load(\"./segmentations/converted_dcm.nii.gz\")\n", + "mv = mv.reformat((\"LR\", \"PA\", \"IS\"))\n", + "np_mv = mv.A\n", + "np_mv = np_mv.astype(np.int32)\n", + "np_mv = np.expand_dims(np_mv, axis=0)\n", + "np_mv = np.expand_dims(np_mv, axis=4)\n", + "\n", + "seg = vx.load(\"./rois/roi.nii.gz\")\n", + "np_seg = seg.A\n", + "np_seg_dim = seg.A\n", + "np_seg = np_seg.astype(np.int32)\n", + "np_seg = np.expand_dims(np_seg, axis=0)\n", + "np_seg = np.expand_dims(np_seg, axis=4)\n", + "\n", + "hip_seg = vx.load(\"./segmentations/hip.nii.gz\")\n", + "hip_seg = hip_seg.reformat((\"LR\", \"PA\", \"IS\"))\n", + "np_hip_seg = hip_seg.A.astype(int)\n", + "# set values not equal to 88 or 89 to 0\n", + "np_hip_seg[(np_hip_seg != 88) & (np_hip_seg != 89)] = 0\n", + "np_hip_seg[np_hip_seg != 0] = np_hip_seg[np_hip_seg != 0] + 4\n", + "np_hip_seg[np_seg_dim != 0] = 0\n", + "np_hip_seg = np_hip_seg.astype(np.int32)\n", + "np_hip_seg = np.expand_dims(np_hip_seg, axis=0)\n", + "np_hip_seg = np.expand_dims(np_hip_seg, axis=4)\n", + "\n", + "ax = tv.Axes(figsize=(512, 512))\n", + "ax.imshow(np_mv)\n", + "ax.imshow(np_seg, cmap=\"seg\")\n", + "ax.imshow(np_hip_seg, cmap=\"seg\")\n", + "ax.show()" + ] + } + ], + "metadata": { + "kernelspec": { + 
"display_name": "Python 3.8.16 ('c2c_env')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "62fd47c2f495fb43260e4f88a1d5487d18d4c091bac4d4df4eca96cade9f1e23" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/Comp2Comp-main/comp2comp/inference_class_base.py b/Comp2Comp-main/comp2comp/inference_class_base.py new file mode 100644 index 0000000000000000000000000000000000000000..c42e1b19374fa4d1fe466db3d1a40ce9e8f72bd7 --- /dev/null +++ b/Comp2Comp-main/comp2comp/inference_class_base.py @@ -0,0 +1,18 @@ +""" +@author: louisblankemeier +""" + +from typing import Dict + + +class InferenceClass: + """Base class for inference classes.""" + + def __init__(self): + pass + + def __call__(self) -> Dict: + raise NotImplementedError + + def __repr__(self): + return self.__class__.__name__ diff --git a/Comp2Comp-main/comp2comp/inference_pipeline.py b/Comp2Comp-main/comp2comp/inference_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..97f4147fbb7951729ef0f31e40ef323e8ee179fa --- /dev/null +++ b/Comp2Comp-main/comp2comp/inference_pipeline.py @@ -0,0 +1,102 @@ +""" +@author: louisblankemeier +""" + +import inspect +import os +from typing import Dict, List + +from comp2comp.inference_class_base import InferenceClass +from comp2comp.io.io import DicomLoader, NiftiSaver + + +class InferencePipeline(InferenceClass): + """Inference pipeline.""" + + def __init__(self, inference_classes: List = None, config: Dict = None): + self.config = config + # assign values from config to attributes + if self.config is not None: + for key, value in self.config.items(): + setattr(self, key, value) + + self.inference_classes = inference_classes + + def __call__(self, inference_pipeline=None, **kwargs): + # print out the class names for each inference class + print("") + print("Inference pipeline:") + for i, inference_class in enumerate(self.inference_classes): + print(f"({i + 1}) {inference_class.__repr__()}") + print("") + + print("Starting inference pipeline.\n") + + if inference_pipeline: + for key, value in kwargs.items(): + setattr(inference_pipeline, key, value) + else: + for key, value in kwargs.items(): + setattr(self, key, value) + + output = {} + for inference_class in self.inference_classes: + function_keys = set(inspect.signature(inference_class).parameters.keys()) + function_keys.remove("inference_pipeline") + + if "kwargs" in function_keys: + function_keys.remove("kwargs") + + assert function_keys == set( + output.keys() + ), "Input to inference class, {}, does not have the correct parameters".format( + inference_class.__repr__() + ) + + print( + "Running {} with input keys {}".format( + inference_class.__repr__(), + inspect.signature(inference_class).parameters.keys(), + ) + ) + + if inference_pipeline: + output = inference_class( + inference_pipeline=inference_pipeline, **output + ) + else: + output = inference_class(inference_pipeline=self, **output) + + # if not the last inference class, check that the output keys are correct + if inference_class != self.inference_classes[-1]: + print( + "Finished {} with output keys {}\n".format( + inference_class.__repr__(), output.keys() + ) + ) + + print("Inference pipeline finished.\n") + + return output + 
+
+if __name__ == "__main__":
+    """Example usage of InferencePipeline."""
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dicom_dir", type=str, required=True)
+    args = parser.parse_args()
+
+    output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../outputs")
+    if not os.path.exists(output_dir):
+        os.mkdir(output_dir)
+    output_file_path = os.path.join(output_dir, "test.nii.gz")
+
+    pipeline = InferencePipeline(
+        [DicomLoader(args.dicom_dir), NiftiSaver()],
+        config={"output_dir": output_file_path},
+    )
+    pipeline()
+
+    print("Done.")
diff --git a/Comp2Comp-main/comp2comp/io/io.py b/Comp2Comp-main/comp2comp/io/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..46a3d1b12441ee7f4ee58bbb47c075354028fe5e
--- /dev/null
+++ b/Comp2Comp-main/comp2comp/io/io.py
@@ -0,0 +1,138 @@
+"""
+@author: louisblankemeier
+"""
+import os
+import shutil
+from pathlib import Path
+from typing import Dict, Union
+
+# import dicom2nifti
+import dosma as dm
+import pydicom
+import SimpleITK as sitk
+
+from comp2comp.inference_class_base import InferenceClass
+
+
+class DicomLoader(InferenceClass):
+    """Load a single dicom series."""
+
+    def __init__(self, input_path: Union[str, Path]):
+        super().__init__()
+        self.dicom_dir = Path(input_path)
+        self.dr = dm.DicomReader()
+
+    def __call__(self, inference_pipeline) -> Dict:
+        medical_volume = self.dr.load(
+            self.dicom_dir, group_by=None, sort_by="InstanceNumber"
+        )[0]
+        return {"medical_volume": medical_volume}
+
+
+class NiftiSaver(InferenceClass):
+    """Save dosma medical volume object to NIfTI file."""
+
+    def __init__(self):
+        super().__init__()
+        # self.output_dir = Path(output_path)
+        self.nw = dm.NiftiWriter()
+
+    def __call__(
+        self, inference_pipeline, medical_volume: dm.MedicalVolume
+    ) -> Dict[str, Path]:
+        nifti_file = inference_pipeline.output_dir
+        self.nw.write(medical_volume, nifti_file)
+        return {"nifti_file": nifti_file}
+
+
+class DicomFinder(InferenceClass):
+    """Find dicom files in a directory."""
+
+    def __init__(self, input_path: Union[str, Path]):
+        super().__init__()
+        self.input_path = Path(input_path)
+
+    def __call__(self, inference_pipeline) -> Dict[str, Path]:
+        """Find dicom files in a directory.
+
+        Args:
+            inference_pipeline (InferencePipeline): Inference pipeline.
+
+        Returns:
+            Dict[str, Path]: Dictionary containing dicom files.
+ """ + dicom_files = [] + for file in self.input_path.glob("**/*.dcm"): + dicom_files.append(file) + inference_pipeline.dicom_file_paths = dicom_files + return {} + + +class DicomToNifti(InferenceClass): + """Convert dicom files to NIfTI files.""" + + def __init__(self, input_path: Union[str, Path], save=True): + super().__init__() + self.input_path = Path(input_path) + self.save = save + + def __call__(self, inference_pipeline): + if os.path.exists( + os.path.join( + inference_pipeline.output_dir, "segmentations", "converted_dcm.nii.gz" + ) + ): + return {} + if hasattr(inference_pipeline, "medical_volume"): + return {} + output_dir = inference_pipeline.output_dir + segmentations_output_dir = os.path.join(output_dir, "segmentations") + os.makedirs(segmentations_output_dir, exist_ok=True) + + # if self.input_path is a folder + if self.input_path.is_dir(): + ds = dicom_series_to_nifti( + self.input_path, + output_file=os.path.join( + segmentations_output_dir, "converted_dcm.nii.gz" + ), + reorient_nifti=False, + ) + inference_pipeline.dicom_series_path = str(self.input_path) + inference_pipeline.dicom_ds = ds + elif str(self.input_path).endswith(".nii"): + shutil.copy( + self.input_path, + os.path.join(segmentations_output_dir, "converted_dcm.nii"), + ) + elif str(self.input_path).endswith(".nii.gz"): + shutil.copy( + self.input_path, + os.path.join(segmentations_output_dir, "converted_dcm.nii.gz"), + ) + + return {} + + +def series_selector(dicom_path): + ds = pydicom.filereader.dcmread(dicom_path) + image_type_list = list(ds.ImageType) + if not any("primary" in s.lower() for s in image_type_list): + raise ValueError("Not primary image type") + if not any("original" in s.lower() for s in image_type_list): + raise ValueError("Not original image type") + # if any("gsi" in s.lower() for s in image_type_list): + # raise ValueError("GSI image type") + if ds.ImageOrientationPatient != [1, 0, 0, 0, 1, 0]: + raise ValueError("Image orientation is not axial") + return ds + + +def dicom_series_to_nifti(input_path, output_file, reorient_nifti): + reader = sitk.ImageSeriesReader() + dicom_names = reader.GetGDCMSeriesFileNames(str(input_path)) + ds = series_selector(dicom_names[0]) + reader.SetFileNames(dicom_names) + image = reader.Execute() + sitk.WriteImage(image, output_file) + return ds diff --git a/Comp2Comp-main/comp2comp/io/io_utils.py b/Comp2Comp-main/comp2comp/io/io_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ef8d5bbad4b66e8b6acbe46a0f29fd376ec3d304 --- /dev/null +++ b/Comp2Comp-main/comp2comp/io/io_utils.py @@ -0,0 +1,77 @@ +""" +@author: louisblankemeier +""" +import csv +import os + +import pydicom + + +def find_dicom_files(input_path): + dicom_series = [] + if not os.path.isdir(input_path): + dicom_series = [str(os.path.abspath(input_path))] + else: + for root, _, files in os.walk(input_path): + for file in files: + if file.endswith(".dcm") or file.endswith(".dicom"): + dicom_series.append(os.path.join(root, file)) + return dicom_series + + +def get_dicom_paths_and_num(path): + """ + Get all paths under a path that contain only dicom files. + Args: + path (str): Path to search. + Returns: + list: List of paths. 
+ """ + dicom_paths = [] + for root, _, files in os.walk(path): + if len(files) > 0: + if all(file.endswith(".dcm") or file.endswith(".dicom") for file in files): + dicom_paths.append((root, len(files))) + + if len(dicom_paths) == 0: + raise ValueError("No scans were found in:\n" + path) + + return dicom_paths + + +def get_dicom_or_nifti_paths_and_num(path): + """Get all paths under a path that contain only dicom files or a nifti file. + Args: + path (str): Path to search. + + Returns: + list: List of paths. + """ + if path.endswith(".nii") or path.endswith(".nii.gz"): + return [(path, 1)] + dicom_nifti_paths = [] + for root, dirs, files in os.walk(path): + if len(files) > 0: + # if all(file.endswith(".dcm") or file.endswith(".dicom") for file in files): + dicom_nifti_paths.append((root, len(files))) + # else: + # for file in files: + # if file.endswith(".nii") or file.endswith(".nii.gz"): + # num_slices = 450 + # dicom_nifti_paths.append((os.path.join(root, file), num_slices)) + + return dicom_nifti_paths + + +def write_dicom_metadata_to_csv(ds, csv_filename): + with open(csv_filename, "w", newline="") as csvfile: + csvwriter = csv.writer(csvfile) + csvwriter.writerow(["Tag", "Keyword", "Value"]) + + for element in ds: + tag = element.tag + keyword = pydicom.datadict.keyword_for_tag(tag) + if keyword == "PixelData": + continue + value = str(element.value) + csvwriter.writerow([tag, keyword, value]) diff --git a/Comp2Comp-main/comp2comp/liver_spleen_pancreas/liver_spleen_pancreas.py b/Comp2Comp-main/comp2comp/liver_spleen_pancreas/liver_spleen_pancreas.py new file mode 100644 index 0000000000000000000000000000000000000000..d330b8c43d3df903b7791f3dc0b5370fef138111 --- /dev/null +++ b/Comp2Comp-main/comp2comp/liver_spleen_pancreas/liver_spleen_pancreas.py @@ -0,0 +1,95 @@ +import os +from pathlib import Path +from time import time +from typing import Union + +from totalsegmentator.libs import ( + download_pretrained_weights, + nostdout, + setup_nnunet, +) + +from comp2comp.inference_class_base import InferenceClass + + +class LiverSpleenPancreasSegmentation(InferenceClass): + """Organ segmentation.""" + + def __init__(self): + super().__init__() + # self.input_path = input_path + + def __call__(self, inference_pipeline): + # inference_pipeline.dicom_series_path = self.input_path + self.output_dir = inference_pipeline.output_dir + self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/") + if not os.path.exists(self.output_dir_segmentations): + os.makedirs(self.output_dir_segmentations) + + self.model_dir = inference_pipeline.model_dir + + mv, seg = self.organ_seg( + os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), + self.output_dir_segmentations + "organs.nii.gz", + inference_pipeline.model_dir, + ) + + inference_pipeline.segmentation = seg + inference_pipeline.medical_volume = mv + + return {} + + def organ_seg( + self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir + ): + """Run organ segmentation. + + Args: + input_path (Union[str, Path]): Input path. + output_path (Union[str, Path]): Output path. 
+ """ + + print("Segmenting organs...") + st = time() + os.environ["SCRATCH"] = self.model_dir + + # Setup nnunet + model = "3d_fullres" + folds = [0] + trainer = "nnUNetTrainerV2_ep4000_nomirror" + crop_path = None + task_id = [251] + + setup_nnunet() + download_pretrained_weights(task_id[0]) + + from totalsegmentator.nnunet import nnUNet_predict_image + + with nostdout(): + seg, mvs = nnUNet_predict_image( + input_path, + output_path, + task_id, + model=model, + folds=folds, + trainer=trainer, + tta=False, + multilabel_image=True, + resample=1.5, + crop=None, + crop_path=crop_path, + task_name="total", + nora_tag="None", + preview=False, + nr_threads_resampling=1, + nr_threads_saving=6, + quiet=False, + verbose=True, + test=0, + ) + end = time() + + # Log total time for spine segmentation + print(f"Total time for organ segmentation: {end-st:.2f}s.") + + return seg, mvs diff --git a/Comp2Comp-main/comp2comp/liver_spleen_pancreas/liver_spleen_pancreas_visualization.py b/Comp2Comp-main/comp2comp/liver_spleen_pancreas/liver_spleen_pancreas_visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..fae6f595dbf982e789d4fe81f3028d3366884a65 --- /dev/null +++ b/Comp2Comp-main/comp2comp/liver_spleen_pancreas/liver_spleen_pancreas_visualization.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os + +import numpy as np + +from comp2comp.inference_class_base import InferenceClass +from comp2comp.liver_spleen_pancreas.visualization_utils import ( + generate_liver_spleen_pancreas_report, + generate_slice_images, +) + + +class LiverSpleenPancreasVisualizer(InferenceClass): + def __init__(self): + super().__init__() + + self.unit_dict = { + "Volume": r"$\mathregular{cm^3}$", + "Mean": "HU", + "Median": "HU", + } + + self.class_nums = [1, 5, 10] + self.organ_names = ["liver", "spleen", "pancreas"] + + def __call__(self, inference_pipeline): + self.output_dir = inference_pipeline.output_dir + self.output_dir_images_organs = os.path.join(self.output_dir, "images/") + inference_pipeline.output_dir_images_organs_organs_organs = ( + self.output_dir_images_organs + ) + + if not os.path.exists(self.output_dir_images_organs): + os.makedirs(self.output_dir_images_organs) + + inference_pipeline.medical_volume_arr = np.flip( + inference_pipeline.medical_volume.get_fdata(), axis=1 + ) + inference_pipeline.segmentation_arr = np.flip( + inference_pipeline.segmentation.get_fdata(), axis=1 + ) + + inference_pipeline.pix_dims = inference_pipeline.medical_volume.header[ + "pixdim" + ][1:4] + inference_pipeline.vol_per_pixel = np.prod( + inference_pipeline.pix_dims / 10 + ) # mm to cm for having ml/pixel. 
+ + self.organ_metrics = generate_slice_images( + inference_pipeline.medical_volume_arr, + inference_pipeline.segmentation_arr, + self.class_nums, + self.unit_dict, + inference_pipeline.vol_per_pixel, + inference_pipeline.pix_dims, + self.output_dir_images_organs, + fontsize=24, + ) + + inference_pipeline.organ_metrics = self.organ_metrics + + generate_liver_spleen_pancreas_report( + self.output_dir_images_organs, self.organ_names + ) + + return {} + + +class LiverSpleenPancreasMetricsPrinter(InferenceClass): + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline): + results = inference_pipeline.organ_metrics + organs = list(results.keys()) + + name_dist = max([len(o) for o in organs]) + metrics = [] + for k in results[list(results.keys())[0]].keys(): + if k != "Organ": + metrics.append(k) + + units = ["cm^3", "HU", "HU"] + + header = ( + "{:<" + str(name_dist + 4) + "}" + ("{:<" + str(15) + "}") * len(metrics) + ) + header = header.format( + "Organ", *[m + "(" + u + ")" for m, u in zip(metrics, units)] + ) + + base_print = ( + "{:<" + str(name_dist + 4) + "}" + ("{:<" + str(15) + ".0f}") * len(metrics) + ) + + print("\n") + print(header) + + for organ in results.values(): + line = base_print.format(*organ.values()) + print(line) + + print("\n") + + output_dir = inference_pipeline.output_dir + self.output_dir_metrics_organs = os.path.join(output_dir, "metrics/") + + if not os.path.exists(self.output_dir_metrics_organs): + os.makedirs(self.output_dir_metrics_organs) + + header = ( + ",".join(["Organ"] + [m + "(" + u + ")" for m, u in zip(metrics, units)]) + + "\n" + ) + with open( + os.path.join( + self.output_dir_metrics_organs, "liver_spleen_pancreas_metrics.csv" + ), + "w", + ) as f: + f.write(header) + + for organ in results.values(): + line = ",".join([str(v) for v in organ.values()]) + "\n" + f.write(line) + + return {} diff --git a/Comp2Comp-main/comp2comp/liver_spleen_pancreas/visualization_utils.py b/Comp2Comp-main/comp2comp/liver_spleen_pancreas/visualization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..492a04944dadbadf829e565e111bae93b7f079a1 --- /dev/null +++ b/Comp2Comp-main/comp2comp/liver_spleen_pancreas/visualization_utils.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os + +import matplotlib.pyplot as plt +import numpy as np +import scipy +from matplotlib.colors import ListedColormap +from PIL import Image + + +def extract_axial_mid_slice(ct, mask, crop=True): + slice_idx = np.argmax(mask.sum(axis=(0, 1))) + + ct_slice_z = np.transpose(ct[:, :, slice_idx], axes=(1, 0)) + mask_slice_z = np.transpose(mask[:, :, slice_idx], axes=(1, 0)) + + ct_slice_z = np.flip(ct_slice_z, axis=(0, 1)) + mask_slice_z = np.flip(mask_slice_z, axis=(0, 1)) + + if crop: + ct_range_x = np.where(ct_slice_z.max(axis=0) > -200)[0][[0, -1]] + + ct_slice_z = ct_slice_z[ + ct_range_x[0] : ct_range_x[1], ct_range_x[0] : ct_range_x[1] + ] + mask_slice_z = mask_slice_z[ + ct_range_x[0] : ct_range_x[1], ct_range_x[0] : ct_range_x[1] + ] + + return ct_slice_z, mask_slice_z + + +def extract_coronal_mid_slice(ct, mask, crop=True): + # find the slice with max coherent extent of the organ + coronary_extent = np.where(mask.sum(axis=(0, 2)))[0] + + max_extent = 0 + max_extent_idx = 0 + + for idx in coronary_extent: + label, num_features = scipy.ndimage.label(mask[:, idx, :]) + + if num_features > 1: + continue + else: + extent = len(np.where(label.sum(axis=1))[0]) + if extent > max_extent: + max_extent = extent + 
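+                # Record the coronal index with the largest single-component
+                # superior-inferior extent; slices where the organ splits into
+                # multiple components were skipped above.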
max_extent_idx = idx + + ct_slice_y = np.transpose(ct[:, max_extent_idx, :], axes=(1, 0)) + mask_slice_y = np.transpose(mask[:, max_extent_idx, :], axes=(1, 0)) + + ct_slice_y = np.flip(ct_slice_y, axis=1) + mask_slice_y = np.flip(mask_slice_y, axis=1) + + return ct_slice_y, mask_slice_y + + +def save_slice( + ct_slice, + mask_slice, + path, + figsize=(12, 12), + corner_text=None, + unit_dict=None, + aspect=1, + show=False, + xy_placement=None, + class_color=1, + fontsize=14, +): + # colormap for shown segmentations + color_array = plt.get_cmap("tab10")(range(10)) + color_array = np.concatenate((np.array([[0, 0, 0, 0]]), color_array[:, :]), axis=0) + map_object_seg = ListedColormap(name="segmentation_cmap", colors=color_array) + + fig, axx = plt.subplots(1, figsize=figsize, frameon=False) + axx.imshow( + ct_slice, + cmap="gray", + vmin=-400, + vmax=400, + interpolation="spline36", + aspect=aspect, + origin="lower", + ) + axx.imshow( + mask_slice * class_color, + cmap=map_object_seg, + vmin=0, + vmax=9, + alpha=0.2, + interpolation="nearest", + aspect=aspect, + origin="lower", + ) + + plt.axis("off") + axx.axes.get_xaxis().set_visible(False) + axx.axes.get_yaxis().set_visible(False) + + y_size, x_size = ct_slice.shape + + if corner_text is not None: + bbox_props = dict(boxstyle="round", facecolor="gray", alpha=0.5) + + texts = [] + for k, v in corner_text.items(): + if isinstance(v, str): + texts.append("{:<9}{}".format(k + ":", v)) + else: + unit = unit_dict[k] if k in unit_dict else "" + texts.append("{:<9}{:.0f} {}".format(k + ":", v, unit)) + + if xy_placement is None: + # get the extent of textbox, remove, and the plot again with correct position + t = axx.text( + 0.5, + 0.5, + "\n".join(texts), + color="white", + transform=axx.transAxes, + fontsize=fontsize, + family="monospace", + bbox=bbox_props, + va="top", + ha="left", + ) + xmin, xmax = t.get_window_extent().xmin, t.get_window_extent().xmax + xmin, xmax = axx.transAxes.inverted().transform((xmin, xmax)) + + xy_placement = [1 - (xmax - xmin) - (xmax - xmin) * 0.09, 0.975] + t.remove() + + axx.text( + xy_placement[0], + xy_placement[1], + "\n".join(texts), + color="white", + transform=axx.transAxes, + fontsize=fontsize, + family="monospace", + bbox=bbox_props, + va="top", + ha="left", + ) + + if show: + plt.show() + else: + fig.savefig(path, bbox_inches="tight", pad_inches=0) + plt.close(fig) + + +def slicedDilationOrErosion(input_mask, num_iteration, operation): + """ + Perform the dilation on the smallest slice that will fit the + segmentation + """ + margin = 2 if num_iteration is None else num_iteration + 1 + + # find the minimum volume enclosing the organ + x_idx = np.where(input_mask.sum(axis=(1, 2)))[0] + x_start, x_end = x_idx[0] - margin, x_idx[-1] + margin + y_idx = np.where(input_mask.sum(axis=(0, 2)))[0] + y_start, y_end = y_idx[0] - margin, y_idx[-1] + margin + z_idx = np.where(input_mask.sum(axis=(0, 1)))[0] + z_start, z_end = z_idx[0] - margin, z_idx[-1] + margin + + struct = scipy.ndimage.generate_binary_structure(3, 1) + struct = scipy.ndimage.iterate_structure(struct, num_iteration) + + if operation == "dilate": + mask_slice = scipy.ndimage.binary_dilation( + input_mask[x_start:x_end, y_start:y_end, z_start:z_end], structure=struct + ).astype(np.int8) + elif operation == "erode": + mask_slice = scipy.ndimage.binary_erosion( + input_mask[x_start:x_end, y_start:y_end, z_start:z_end], structure=struct + ).astype(np.int8) + + output_mask = input_mask.copy() + + output_mask[x_start:x_end, y_start:y_end, z_start:z_end] 
= mask_slice + + return output_mask + + +def extract_organ_metrics( + ct, all_masks, class_num=None, vol_per_pixel=None, erode_mask=True +): + if erode_mask: + eroded_mask = slicedDilationOrErosion( + input_mask=(all_masks == class_num), num_iteration=3, operation="erode" + ) + ct_organ_vals = ct[eroded_mask == 1] + else: + ct_organ_vals = ct[all_masks == class_num] + + results = {} + + # in ml + organ_vol = (all_masks == class_num).sum() * vol_per_pixel + organ_mean = ct_organ_vals.mean() + organ_median = np.median(ct_organ_vals) + + results = { + "Organ": class_map_part_organs[class_num], + "Volume": organ_vol, + "Mean": organ_mean, + "Median": organ_median, + } + + return results + + +def generate_slice_images( + ct, + all_masks, + class_nums, + unit_dict, + vol_per_pixel, + pix_dims, + root, + fontsize=20, + show=False, +): + all_results = {} + + colors = [1, 3, 4] + + for i, c_num in enumerate(class_nums): + organ_name = class_map_part_organs[c_num] + + axial_path = os.path.join(root, organ_name.lower() + "_axial.png") + coronal_path = os.path.join(root, organ_name.lower() + "_coronal.png") + + ct_slice_z, liver_slice_z = extract_axial_mid_slice(ct, all_masks == c_num) + results = extract_organ_metrics( + ct, all_masks, class_num=c_num, vol_per_pixel=vol_per_pixel + ) + + save_slice( + ct_slice_z, + liver_slice_z, + axial_path, + figsize=(12, 12), + corner_text=results, + unit_dict=unit_dict, + class_color=colors[i], + fontsize=fontsize, + show=show, + ) + + ct_slice_y, liver_slice_y = extract_coronal_mid_slice(ct, all_masks == c_num) + + save_slice( + ct_slice_y, + liver_slice_y, + coronal_path, + figsize=(12, 12), + aspect=pix_dims[2] / pix_dims[1], + show=show, + class_color=colors[i], + ) + + all_results[results["Organ"]] = results + + if show: + return + + return all_results + + +def generate_liver_spleen_pancreas_report(root, organ_names): + axial_imgs = [ + Image.open(os.path.join(root, organ + "_axial.png")) for organ in organ_names + ] + coronal_imgs = [ + Image.open(os.path.join(root, organ + "_coronal.png")) for organ in organ_names + ] + + result_width = max( + sum([img.size[0] for img in axial_imgs]), + sum([img.size[0] for img in coronal_imgs]), + ) + result_height = max( + [a.size[1] + c.size[1] for a, c in zip(axial_imgs, coronal_imgs)] + ) + + result = Image.new("RGB", (result_width, result_height)) + + total_width = 0 + + for a_img, c_img in zip(axial_imgs, coronal_imgs): + a_width, a_height = a_img.size + c_width, c_height = c_img.size + + translate = (a_width - c_width) // 2 if a_width > c_width else 0 + + result.paste(im=a_img, box=(total_width, 0)) + result.paste(im=c_img, box=(translate + total_width, a_height)) + + total_width += a_width + + result.save(os.path.join(root, "liver_spleen_pancreas_report.png")) + + +# from https://github.com/wasserth/TotalSegmentator/blob/master/totalsegmentator/map_to_binary.py + +class_map_part_organs = { + 1: "Spleen", + 2: "Right Kidney", + 3: "Left Kidney", + 4: "Gallbladder", + 5: "Liver", + 6: "Stomach", + 7: "Aorta", + 8: "Inferior vena cava", + 9: "portal Vein and Splenic Vein", + 10: "Pancreas", + 11: "Right Adrenal Gland", + 12: "Left Adrenal Gland Left", + 13: "lung_upper_lobe_left", + 14: "lung_lower_lobe_left", + 15: "lung_upper_lobe_right", + 16: "lung_middle_lobe_right", + 17: "lung_lower_lobe_right", +} diff --git a/Comp2Comp-main/comp2comp/metrics/metrics.py b/Comp2Comp-main/comp2comp/metrics/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..b5c07e3ebedbdc61928b7416b0f3dc9c1b4a205b 
--- /dev/null
+++ b/Comp2Comp-main/comp2comp/metrics/metrics.py
@@ -0,0 +1,156 @@
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from typing import Sequence, Union
+
+import numpy as np
+
+
+def flatten_non_category_dims(
+    xs: Union[np.ndarray, Sequence[np.ndarray]], category_dim: int = None
+):
+    """Flattens all non-category dimensions into a single dimension.
+
+    Args:
+        xs (ndarrays): Sequence of ndarrays with the same category dimension.
+        category_dim: The dimension/axis corresponding to different categories,
+            i.e. `C`. If `None`, behaves like `x.flatten()`.
+
+    Returns:
+        ndarray: Shape (C, -1) if `category_dim` is specified, else shape (-1,).
+    """
+    single_item = isinstance(xs, np.ndarray)
+
+    if single_item:
+        xs = [xs]
+
+    if category_dim is not None:
+        dims = (xs[0].shape[category_dim], -1)
+        xs = (np.moveaxis(x, category_dim, 0).reshape(dims) for x in xs)
+    else:
+        xs = (x.flatten() for x in xs)
+
+    if single_item:
+        return list(xs)[0]
+    else:
+        return xs
+
+
+class Metric(Callable, ABC):
+    """Interface for new metrics.
+
+    A metric should be implemented as a callable with explicitly defined
+    arguments. In other words, metrics should not have `**kwargs` or `*args`
+    options in the `__call__` method.
+
+    While not explicitly constrained to the return type, metrics typically
+    return float value(s). The number of values returned corresponds to the
+    number of categories.
+
+    * Metrics should have a different `name()` for different functionality.
+    * `category_dim` duck types whether the metric can process multiple
+      categories at once.
+
+    To compute metrics:
+
+    .. code-block:: python
+
+        metric = Metric()
+        results = metric(...)
+    """
+
+    def __init__(self, units: str = ""):
+        self.units = units
+
+    def name(self):
+        return type(self).__name__
+
+    def display_name(self):
+        """Name to use for pretty printing and display purposes."""
+        name = self.name()
+        return "{} {}".format(name, self.units) if self.units else name
+
+    @abstractmethod
+    def __call__(self, *args, **kwargs):
+        pass
+
+
+class HounsfieldUnits(Metric):
+    FULL_NAME = "Hounsfield Unit"
+
+    def __init__(self, units="hu"):
+        super().__init__(units)
+
+    def __call__(self, mask, x, category_dim: int = None):
+        mask = mask.astype(bool)
+        if category_dim is None:
+            return np.mean(x[mask])
+
+        assert category_dim == -1
+        num_classes = mask.shape[-1]
+
+        return np.array([np.mean(x[mask[..., c]]) for c in range(num_classes)])
+
+    def name(self):
+        return self.FULL_NAME
+
+
+class CrossSectionalArea(Metric):
+    def __call__(self, mask, spacing=None, category_dim: int = None):
+        pixel_area = np.prod(spacing) if spacing else 1
+        mask = mask.astype(bool)
+        mask = flatten_non_category_dims(mask, category_dim)
+
+        return pixel_area * np.count_nonzero(mask, -1) / 100.0
+
+    def name(self):
+        if self.units:
+            return "Cross-sectional Area ({})".format(self.units)
+        else:
+            return "Cross-sectional Area"
+
+
+def manifest_to_map(manifest, model_type):
+    """Converts a manifest to a map of level/file name to formatted metric strings.
+
+    Args:
+        manifest (list): List of per-slice dictionaries of metric names to values.
+        model_type (Models): Model enum; determines the ordering of the output values.
+
+    Returns:
+        dict: A dictionary mapping level/file name to a list of formatted metrics.
+    """
+    # TODO: hacky.
Update this + figure_text_key = {} + for manifest_dict in manifest: + try: + key = manifest_dict["Level"] + except BaseException: + key = ".".join((manifest_dict["File"].split("/")[-1]).split(".")[:-1]) + muscle_hu = f"{manifest_dict['Hounsfield Unit (muscle)']:.2f}" + muscle_area = f"{manifest_dict['Cross-sectional Area (cm^2) (muscle)']:.2f}" + vat_hu = f"{manifest_dict['Hounsfield Unit (vat)']:.2f}" + vat_area = f"{manifest_dict['Cross-sectional Area (cm^2) (vat)']:.2f}" + sat_hu = f"{manifest_dict['Hounsfield Unit (sat)']:.2f}" + sat_area = f"{manifest_dict['Cross-sectional Area (cm^2) (sat)']:.2f}" + imat_hu = f"{manifest_dict['Hounsfield Unit (imat)']:.2f}" + imat_area = f"{manifest_dict['Cross-sectional Area (cm^2) (imat)']:.2f}" + if model_type.model_name == "abCT_v0.0.1": + figure_text_key[key] = [ + muscle_hu, + muscle_area, + imat_hu, + imat_area, + vat_hu, + vat_area, + sat_hu, + sat_area, + ] + else: + figure_text_key[key] = [ + muscle_hu, + muscle_area, + vat_hu, + vat_area, + sat_hu, + sat_area, + imat_hu, + imat_area, + ] + return figure_text_key diff --git a/Comp2Comp-main/comp2comp/models/models.py b/Comp2Comp-main/comp2comp/models/models.py new file mode 100644 index 0000000000000000000000000000000000000000..b972f64cf17ce1e0181b180f6aa2a9b60ec06963 --- /dev/null +++ b/Comp2Comp-main/comp2comp/models/models.py @@ -0,0 +1,157 @@ +import enum +import os +from pathlib import Path +from typing import Dict, Sequence + +import wget +from keras.models import load_model + + +class Models(enum.Enum): + ABCT_V_0_0_1 = ( + 1, + "abCT_v0.0.1", + {"muscle": 0, "imat": 1, "vat": 2, "sat": 3}, + False, + ("soft", "bone", "custom"), + ) + + STANFORD_V_0_0_1 = ( + 2, + "stanford_v0.0.1", + # ("background", "muscle", "bone", "vat", "sat", "imat"), + # Category name mapped to channel index + {"muscle": 1, "vat": 3, "sat": 4, "imat": 5}, + True, + ("soft", "bone", "custom"), + ) + + STANFORD_V_0_0_2 = ( + 3, + "stanford_v0.0.2", + {"muscle": 4, "sat": 1, "vat": 2, "imat": 3}, + True, + ("soft", "bone", "custom"), + ) + TS_SPINE_FULL = ( + 4, + "ts_spine_full", + # Category name mapped to channel index + { + "L5": 18, + "L4": 19, + "L3": 20, + "L2": 21, + "L1": 22, + "T12": 23, + "T11": 24, + "T10": 25, + "T9": 26, + "T8": 27, + "T7": 28, + "T6": 29, + "T5": 30, + "T4": 31, + "T3": 32, + "T2": 33, + "T1": 34, + "C7": 35, + "C6": 36, + "C5": 37, + "C4": 38, + "C3": 39, + "C2": 40, + "C1": 41, + }, + False, + (), + ) + TS_SPINE = ( + 5, + "ts_spine", + # Category name mapped to channel index + # {"L5": 18, "L4": 19, "L3": 20, "L2": 21, "L1": 22, "T12": 23}, + {"L5": 27, "L4": 28, "L3": 29, "L2": 30, "L1": 31, "T12": 32}, + False, + (), + ) + STANFORD_SPINE_V_0_0_1 = ( + 6, + "stanford_spine_v0.0.1", + # Category name mapped to channel index + {"L5": 24, "L4": 23, "L3": 22, "L2": 21, "L1": 20, "T12": 19}, + False, + (), + ) + TS_HIP = ( + 7, + "ts_hip", + # Category name mapped to channel index + {"femur_left": 88, "femur_right": 89}, + False, + (), + ) + + def __new__( + cls, + value: int, + model_name: str, + categories: Dict[str, int], + use_softmax: bool, + windows: Sequence[str], + ): + obj = object.__new__(cls) + obj._value_ = value + + obj.model_name = model_name + obj.categories = categories + obj.use_softmax = use_softmax + obj.windows = windows + return obj + + def load_model(self, model_dir): + """Load the model from the models directory. + + Args: + logger (logging.Logger): Logger. + + Returns: + keras.models.Model: Model. 
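+
+        Example (sketch; the weights directory path is an assumption):
+
+            >>> model = Models.STANFORD_V_0_0_1.load_model("/path/to/model_dir")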
+ """ + try: + filename = Models.find_model_weights(self.model_name, model_dir) + except Exception: + print("Downloading muscle/fat model from hugging face") + Path(model_dir).mkdir(parents=True, exist_ok=True) + wget.download( + f"https://huggingface.co/stanfordmimi/stanford_abct_v0.0.1/resolve/main/{self.model_name}.h5", + out=os.path.join(model_dir, f"{self.model_name}.h5"), + ) + filename = Models.find_model_weights(self.model_name, model_dir) + print("") + + print("Loading muscle/fat model from {}".format(filename)) + return load_model(filename) + + @staticmethod + def model_from_name(model_name): + """Get the model enum from the model name. + + Args: + model_name (str): Model name. + + Returns: + Models: Model enum. + """ + for model in Models: + if model.model_name == model_name: + return model + return None + + @staticmethod + def find_model_weights(file_name, model_dir): + for root, _, files in os.walk(model_dir): + for file in files: + if file.startswith(file_name): + filename = os.path.join(root, file) + return filename diff --git a/Comp2Comp-main/comp2comp/muscle_adipose_tissue/data.py b/Comp2Comp-main/comp2comp/muscle_adipose_tissue/data.py new file mode 100644 index 0000000000000000000000000000000000000000..54843bf823503d1505b2c62fd083b69278ada84f --- /dev/null +++ b/Comp2Comp-main/comp2comp/muscle_adipose_tissue/data.py @@ -0,0 +1,214 @@ +import math +from typing import List, Sequence + +import keras.utils as k_utils +import numpy as np +import pydicom +from keras.utils.data_utils import OrderedEnqueuer +from tqdm import tqdm + + +def parse_windows(windows): + """Parse windows provided by the user. + + These windows can either be strings corresponding to popular windowing + thresholds for CT or tuples of (upper, lower) bounds. + + Args: + windows (list): List of strings or tuples. + + Returns: + list: List of tuples of (upper, lower) bounds. + """ + windowing = { + "soft": (400, 50), + "bone": (1800, 400), + "liver": (150, 30), + "spine": (250, 50), + "custom": (500, 50), + } + vals = [] + for w in windows: + if isinstance(w, Sequence) and len(w) == 2: + assert_msg = "Expected tuple of (lower, upper) bound" + assert len(w) == 2, assert_msg + assert isinstance(w[0], (float, int)), assert_msg + assert isinstance(w[1], (float, int)), assert_msg + assert w[0] < w[1], assert_msg + vals.append(w) + continue + + if w not in windowing: + raise KeyError("Window {} not found".format(w)) + window_width = windowing[w][0] + window_level = windowing[w][1] + upper = window_level + window_width / 2 + lower = window_level - window_width / 2 + + vals.append((lower, upper)) + + return tuple(vals) + + +def _window(xs, bounds): + """Apply windowing to an array of CT images. + + Args: + xs (ndarray): NxHxW + bounds (tuple): (lower, upper) bounds + + Returns: + ndarray: Windowed images. 
+ """ + + imgs = [] + for lb, ub in bounds: + imgs.append(np.clip(xs, a_min=lb, a_max=ub)) + + if len(imgs) == 1: + return imgs[0] + elif xs.shape[-1] == 1: + return np.concatenate(imgs, axis=-1) + else: + return np.stack(imgs, axis=-1) + + +class Dataset(k_utils.Sequence): + def __init__(self, files: List[str], batch_size: int = 16, windows=None): + self._files = files + self._batch_size = batch_size + self.windows = windows + + def __len__(self): + return math.ceil(len(self._files) / self._batch_size) + + def __getitem__(self, idx): + files = self._files[idx * self._batch_size : (idx + 1) * self._batch_size] + dcms = [pydicom.read_file(f, force=True) for f in files] + + xs = [(x.pixel_array + int(x.RescaleIntercept)).astype("float32") for x in dcms] + + params = [ + {"spacing": header.PixelSpacing, "image": x} for header, x in zip(dcms, xs) + ] + + # Preprocess xs via windowing. + xs = np.stack(xs, axis=0) + if self.windows: + xs = _window(xs, parse_windows(self.windows)) + else: + xs = xs[..., np.newaxis] + + return xs, params + + +def _swap_muscle_imap(xs, ys, muscle_idx: int, imat_idx: int, threshold=-30.0): + """ + If pixel labeled as muscle but has HU < threshold, change label to imat. + + Args: + xs (ndarray): NxHxWxC + ys (ndarray): NxHxWxC + muscle_idx (int): Index of the muscle label. + imat_idx (int): Index of the imat label. + threshold (float): Threshold for HU value. + + Returns: + ndarray: Segmentation mask with swapped labels. + """ + labels = ys.copy() + + muscle_mask = (labels[..., muscle_idx] > 0.5).astype(int) + imat_mask = labels[..., imat_idx] + + imat_mask[muscle_mask.astype(np.bool) & (xs < threshold)] = 1 + muscle_mask[xs < threshold] = 0 + + labels[..., muscle_idx] = muscle_mask + labels[..., imat_idx] = imat_mask + + return labels + + +def postprocess(xs: np.ndarray, ys: np.ndarray): + """Built-in post-processing. + + TODO: Make this configurable. + + Args: + xs (ndarray): NxHxW + ys (ndarray): NxHxWxC + params (dictionary): Post-processing parameters. Must contain + "categories". + + Returns: + ndarray: Post-processed labels. + """ + + # Add another channel full of zeros to ys + ys = np.concatenate([ys, np.zeros_like(ys[..., :1])], axis=-1) + + # If muscle hu is < -30, assume it is imat. + + """ + if "muscle" in categories and "imat" in categories: + ys = _swap_muscle_imap( + xs, + ys, + muscle_idx=categories["muscle"], + imat_idx=categories["imat"], + ) + """ + + return ys + + +def predict( + model, + dataset: Dataset, + batch_size: int = 16, + num_workers: int = 1, + max_queue_size: int = 10, + use_multiprocessing: bool = False, +): + """Predict segmentation masks for a dataset. + + Args: + model (keras.Model): Model to use for prediction. + dataset (Dataset): Dataset to predict on. + batch_size (int): Batch size. + num_workers (int): Number of workers. + max_queue_size (int): Maximum queue size. + use_multiprocessing (bool): Use multiprocessing. + use_postprocessing (bool): Use built-in post-processing. + postprocessing_params (dict): Post-processing parameters. + + Returns: + List: List of segmentation masks. 
+ """ + + if num_workers > 0: + enqueuer = OrderedEnqueuer( + dataset, use_multiprocessing=use_multiprocessing, shuffle=False + ) + enqueuer.start(workers=num_workers, max_queue_size=max_queue_size) + output_generator = enqueuer.get() + else: + output_generator = iter(dataset) + + num_scans = len(dataset) + xs = [] + ys = [] + params = [] + for _ in tqdm(range(num_scans)): + x, p_dicts = next(output_generator) + y = model.predict(x, batch_size=batch_size) + + image = np.stack([out["image"] for out in p_dicts], axis=0) + y = postprocess(image, y) + + params.extend(p_dicts) + xs.extend([x[i, ...] for i in range(len(x))]) + ys.extend([y[i, ...] for i in range(len(y))]) + + return xs, ys, params diff --git a/Comp2Comp-main/comp2comp/muscle_adipose_tissue/muscle_adipose_tissue.py b/Comp2Comp-main/comp2comp/muscle_adipose_tissue/muscle_adipose_tissue.py new file mode 100644 index 0000000000000000000000000000000000000000..a83b576ae2e005d0ac0d81b8ba856017a8f11606 --- /dev/null +++ b/Comp2Comp-main/comp2comp/muscle_adipose_tissue/muscle_adipose_tissue.py @@ -0,0 +1,445 @@ +import os +import zipfile +from pathlib import Path +from time import perf_counter +from typing import List, Union + +import cv2 +import h5py +import nibabel as nib +import numpy as np +import pandas as pd +import wget +from keras import backend as K +from tqdm import tqdm + +from comp2comp.inference_class_base import InferenceClass +from comp2comp.metrics.metrics import CrossSectionalArea, HounsfieldUnits +from comp2comp.models.models import Models +from comp2comp.muscle_adipose_tissue.data import Dataset, predict + + +class MuscleAdiposeTissueSegmentation(InferenceClass): + """Muscle adipose tissue segmentation class.""" + + def __init__(self, batch_size: int, model_name: str, model_dir: str = None): + super().__init__() + self.batch_size = batch_size + self.model_name = model_name + self.model_type = Models.model_from_name(model_name) + + def forward_pass_2d(self, files): + dataset = Dataset(files, windows=self.model_type.windows) + num_workers = 1 + + print("Computing segmentation masks using {}...".format(self.model_name)) + start_time = perf_counter() + _, preds, results = predict( + self.model, + dataset, + num_workers=num_workers, + use_multiprocessing=num_workers > 1, + batch_size=self.batch_size, + ) + K.clear_session() + print( + f"Completed {len(files)} segmentations in {(perf_counter() - start_time):.2f} seconds." 
+ ) + for i in range(len(results)): + results[i]["preds"] = preds[i] + return results + + def download_muscle_adipose_tissue_model(self, model_dir: Union[str, Path]): + download_dir = Path( + os.path.join( + model_dir, + ".totalsegmentator/nnunet/results/nnUNet/2d/Task927_FatMuscle/nnUNetTrainerV2__nnUNetPlansv2.1", + ) + ) + all_path = download_dir / "all" + if not os.path.exists(all_path): + download_dir.mkdir(parents=True, exist_ok=True) + wget.download( + "https://huggingface.co/stanfordmimi/multilevel_muscle_adipose_tissue/resolve/main/all.zip", + out=os.path.join(download_dir, "all.zip"), + ) + with zipfile.ZipFile(os.path.join(download_dir, "all.zip"), "r") as zip_ref: + zip_ref.extractall(download_dir) + os.remove(os.path.join(download_dir, "all.zip")) + wget.download( + "https://huggingface.co/stanfordmimi/multilevel_muscle_adipose_tissue/resolve/main/plans.pkl", + out=os.path.join(download_dir, "plans.pkl"), + ) + print("Muscle and adipose tissue model downloaded.") + else: + print("Muscle and adipose tissue model already downloaded.") + + def __call__(self, inference_pipeline): + inference_pipeline.muscle_adipose_tissue_model_type = self.model_type + inference_pipeline.muscle_adipose_tissue_model_name = self.model_name + + if self.model_name == "stanford_v0.0.2": + self.download_muscle_adipose_tissue_model(inference_pipeline.model_dir) + nifti_path = os.path.join( + inference_pipeline.output_dir, "segmentations", "converted_dcm.nii.gz" + ) + output_path = os.path.join( + inference_pipeline.output_dir, + "segmentations", + "converted_dcm_seg.nii.gz", + ) + + from nnunet.inference import predict + + predict.predict_cases( + model=os.path.join( + inference_pipeline.model_dir, + ".totalsegmentator/nnunet/results/nnUNet/2d/Task927_FatMuscle/nnUNetTrainerV2__nnUNetPlansv2.1", + ), + list_of_lists=[[nifti_path]], + output_filenames=[output_path], + folds="all", + save_npz=False, + num_threads_preprocessing=8, + num_threads_nifti_save=8, + segs_from_prev_stage=None, + do_tta=False, + mixed_precision=True, + overwrite_existing=False, + all_in_gpu=False, + step_size=0.5, + checkpoint_name="model_final_checkpoint", + segmentation_export_kwargs=None, + ) + + image_nib = nib.load(nifti_path) + image_nib = nib.as_closest_canonical(image_nib) + image = image_nib.get_fdata() + pred = nib.load(output_path) + pred = nib.as_closest_canonical(pred) + pred = pred.get_fdata() + + images = [image[:, :, i] for i in range(image.shape[-1])] + preds = [pred[:, :, i] for i in range(pred.shape[-1])] + + # flip both axes and transpose + images = [np.flip(np.flip(image, axis=0), axis=1).T for image in images] + preds = [np.flip(np.flip(pred, axis=0), axis=1).T for pred in preds] + + spacings = [ + image_nib.header.get_zooms()[0:2] for i in range(image.shape[-1]) + ] + + categories = self.model_type.categories + + # for each image in images, convert to one hot encoding + masks = [] + for pred in preds: + mask = np.zeros((pred.shape[0], pred.shape[1], 4)) + for i, category in enumerate(categories): + mask[:, :, i] = pred == categories[category] + mask = mask.astype(np.uint8) + masks.append(mask) + return {"images": images, "preds": masks, "spacings": spacings} + + else: + dicom_file_paths = inference_pipeline.dicom_file_paths + # if dicom_file_names not an attribute of inference_pipeline, add it + if not hasattr(inference_pipeline, "dicom_file_names"): + inference_pipeline.dicom_file_names = [ + dicom_file_path.stem for dicom_file_path in dicom_file_paths + ] + self.model = 
self.model_type.load_model(inference_pipeline.model_dir) + + results = self.forward_pass_2d(dicom_file_paths) + images = [] + for result in results: + images.append(result["image"]) + preds = [] + for result in results: + preds.append(result["preds"]) + spacings = [] + for result in results: + spacings.append(result["spacing"]) + + return {"images": images, "preds": preds, "spacings": spacings} + + +class MuscleAdiposeTissuePostProcessing(InferenceClass): + """Post-process muscle and adipose tissue segmentation.""" + + def __init__(self): + super().__init__() + + def preds_to_mask(self, preds): + """Convert model predictions to a mask. + + Args: + preds (np.ndarray): Model predictions. + + Returns: + np.ndarray: Mask. + """ + if self.use_softmax: + # softmax + labels = np.zeros_like(preds, dtype=np.uint8) + l_argmax = np.argmax(preds, axis=-1) + for c in range(labels.shape[-1]): + labels[l_argmax == c, c] = 1 + return labels.astype(np.bool) + else: + # sigmoid + return preds >= 0.5 + + def __call__(self, inference_pipeline, images, preds, spacings): + """Post-process muscle and adipose tissue segmentation.""" + self.model_type = inference_pipeline.muscle_adipose_tissue_model_type + self.use_softmax = self.model_type.use_softmax + self.model_name = inference_pipeline.muscle_adipose_tissue_model_name + return self.post_process(images, preds, spacings) + + def remove_small_objects(self, mask, min_size=10): + mask = mask.astype(np.uint8) + components, output, stats, centroids = cv2.connectedComponentsWithStats( + mask, connectivity=8 + ) + sizes = stats[1:, -1] + mask = np.zeros((output.shape)) + for i in range(0, components - 1): + if sizes[i] >= min_size: + mask[output == i + 1] = 1 + return mask + + def post_process( + self, + images, + preds, + spacings, + ): + categories = self.model_type.categories + + start_time = perf_counter() + + if self.model_name == "stanford_v0.0.2": + masks = preds + else: + masks = [self.preds_to_mask(p) for p in preds] + + for i, _ in enumerate(masks): + # Keep only channels from the model_type categories dict + masks[i] = np.squeeze(masks[i]) + + masks = self.fill_holes(masks) + + cats = list(categories.keys()) + + file_idx = 0 + for mask, image in tqdm(zip(masks, images), total=len(masks)): + muscle_mask = mask[..., cats.index("muscle")] + imat_mask = mask[..., cats.index("imat")] + imat_mask = ( + np.logical_and( + (image * muscle_mask) <= -30, (image * muscle_mask) >= -190 + ) + ).astype(int) + imat_mask = self.remove_small_objects(imat_mask) + mask[..., cats.index("imat")] += imat_mask + mask[..., cats.index("muscle")][imat_mask == 1] = 0 + masks[file_idx] = mask + images[file_idx] = image + file_idx += 1 + + print( + f"Completed post-processing in {(perf_counter() - start_time):.2f} seconds." + ) + + return {"images": images, "masks": masks, "spacings": spacings} + + # function that fills in holes in a segmentation mask + def _fill_holes(self, mask: np.ndarray, mask_id: int): + """Fill in holes in a segmentation mask. + + Args: + mask (ndarray): NxHxW + mask_id (int): Label of the mask. + + Returns: + ndarray: Filled mask. 
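+
+        Note: as called from ``fill_holes`` below, ``mask`` is a single 2D
+        channel (HxW) of one slice, and ``mask_id`` is the channel index used
+        to pick a per-category minimum hole size (larger for SAT).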
+ """ + int_mask = ((1 - mask) > 0.5).astype(np.int8) + components, output, stats, _ = cv2.connectedComponentsWithStats( + int_mask, connectivity=8 + ) + sizes = stats[1:, -1] + components = components - 1 + # Larger threshold for SAT + # TODO make this configurable / parameter + if mask_id == 2: + min_size = 200 + else: + # min_size = 50 # Smaller threshold for everything else + min_size = 20 + img_out = np.ones_like(mask) + for i in range(0, components): + if sizes[i] > min_size: + img_out[output == i + 1] = 0 + return img_out + + def fill_holes(self, ys: List): + """Take an array of size NxHxWxC and for each channel fill in holes. + + Args: + ys (list): List of segmentation masks. + """ + segs = [] + for n in range(len(ys)): + ys_out = [ + self._fill_holes(ys[n][..., i], i) for i in range(ys[n].shape[-1]) + ] + segs.append(np.stack(ys_out, axis=2).astype(float)) + + return segs + + +class MuscleAdiposeTissueComputeMetrics(InferenceClass): + """Compute muscle and adipose tissue metrics.""" + + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline, images, masks, spacings): + """Compute muscle and adipose tissue metrics.""" + self.model_type = inference_pipeline.muscle_adipose_tissue_model_type + self.model_name = inference_pipeline.muscle_adipose_tissue_model_name + metrics = self.compute_metrics_all(images, masks, spacings) + return metrics + + def compute_metrics_all(self, images, masks, spacings): + """Compute metrics for all images and masks. + + Args: + images (List[np.ndarray]): Images. + masks (List[np.ndarray]): Masks. + + Returns: + Dict: Results. + """ + results = [] + for image, mask, spacing in zip(images, masks, spacings): + results.append(self.compute_metrics(image, mask, spacing)) + return {"images": images, "results": results} + + def compute_metrics(self, x, mask, spacing): + """Compute results for a given segmentation.""" + categories = self.model_type.categories + + hu = HounsfieldUnits() + csa_units = "cm^2" if spacing else "" + csa = CrossSectionalArea(csa_units) + + hu_vals = hu(mask, x, category_dim=-1) + csa_vals = csa(mask=mask, spacing=spacing, category_dim=-1) + + # check if any values are nan and replace with 0 + hu_vals = np.nan_to_num(hu_vals) + csa_vals = np.nan_to_num(csa_vals) + + assert mask.shape[-1] == len( + categories + ), "{} categories found in mask, " "but only {} categories specified".format( + mask.shape[-1], len(categories) + ) + + results = { + cat: { + "mask": mask[..., idx], + hu.name(): hu_vals[idx], + csa.name(): csa_vals[idx], + } + for idx, cat in enumerate(categories.keys()) + } + return results + + +class MuscleAdiposeTissueH5Saver(InferenceClass): + """Save results to an HDF5 file.""" + + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline, results): + """Save results to an HDF5 file.""" + self.model_type = inference_pipeline.muscle_adipose_tissue_model_type + self.model_name = inference_pipeline.muscle_adipose_tissue_model_name + self.output_dir = inference_pipeline.output_dir + self.h5_output_dir = os.path.join(self.output_dir, "segmentations") + os.makedirs(self.h5_output_dir, exist_ok=True) + self.dicom_file_paths = inference_pipeline.dicom_file_paths + self.dicom_file_names = inference_pipeline.dicom_file_names + self.save_results(results) + return {"results": results} + + def save_results(self, results): + """Save results to an HDF5 file.""" + categories = self.model_type.categories + cats = list(categories.keys()) + + for i, result in enumerate(results): + file_name = 
self.dicom_file_names[i] + with h5py.File( + os.path.join(self.h5_output_dir, file_name + ".h5"), "w" + ) as f: + for cat in cats: + mask = result[cat]["mask"] + f.create_dataset(name=cat, data=np.array(mask, dtype=np.uint8)) + + +class MuscleAdiposeTissueMetricsSaver(InferenceClass): + """Save metrics to a CSV file.""" + + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline, results): + """Save metrics to a CSV file.""" + self.model_type = inference_pipeline.muscle_adipose_tissue_model_type + self.model_name = inference_pipeline.muscle_adipose_tissue_model_name + self.output_dir = inference_pipeline.output_dir + self.csv_output_dir = os.path.join(self.output_dir, "metrics") + os.makedirs(self.csv_output_dir, exist_ok=True) + self.dicom_file_paths = inference_pipeline.dicom_file_paths + self.dicom_file_names = inference_pipeline.dicom_file_names + self.save_results(results) + return {} + + def save_results(self, results): + """Save results to a CSV file.""" + self.model_type.categories + df = pd.DataFrame( + columns=[ + "Level", + "Index", + "Muscle HU", + "Muscle CSA (cm^2)", + "SAT HU", + "SAT CSA (cm^2)", + "VAT HU", + "VAT CSA (cm^2)", + "IMAT HU", + "IMAT CSA (cm^2)", + ] + ) + + for i, result in enumerate(results): + row = [] + row.append(self.dicom_file_names[i]) + row.append(self.dicom_file_paths[i]) + for cat in result: + row.append(result[cat]["Hounsfield Unit"]) + row.append(result[cat]["Cross-sectional Area (cm^2)"]) + df.loc[i] = row + df = df.iloc[::-1] + df.to_csv( + os.path.join(self.csv_output_dir, "muscle_adipose_tissue_metrics.csv"), + index=False, + ) diff --git a/Comp2Comp-main/comp2comp/muscle_adipose_tissue/muscle_adipose_tissue_visualization.py b/Comp2Comp-main/comp2comp/muscle_adipose_tissue/muscle_adipose_tissue_visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..56bf8420fa36f1d55f3ef23a4aaaa3b8b1d1a167 --- /dev/null +++ b/Comp2Comp-main/comp2comp/muscle_adipose_tissue/muscle_adipose_tissue_visualization.py @@ -0,0 +1,181 @@ +""" +@author: louisblankemeier +""" + +import os +from pathlib import Path + +import numpy as np + +from comp2comp.inference_class_base import InferenceClass +from comp2comp.visualization.detectron_visualizer import Visualizer + + +class MuscleAdiposeTissueVisualizer(InferenceClass): + def __init__(self): + super().__init__() + + self._spine_colors = { + "L5": [255, 0, 0], + "L4": [0, 255, 0], + "L3": [255, 255, 0], + "L2": [255, 128, 0], + "L1": [0, 255, 255], + "T12": [255, 0, 255], + } + + self._muscle_fat_colors = { + "muscle": [255, 136, 133], + "imat": [154, 135, 224], + "vat": [140, 197, 135], + "sat": [246, 190, 129], + } + + self._SPINE_TEXT_OFFSET_FROM_TOP = 10.0 + self._SPINE_TEXT_OFFSET_FROM_RIGHT = 63.0 + self._SPINE_TEXT_VERTICAL_SPACING = 14.0 + + self._MUSCLE_FAT_TEXT_HORIZONTAL_SPACING = 40.0 + self._MUSCLE_FAT_TEXT_VERTICAL_SPACING = 14.0 + self._MUSCLE_FAT_TEXT_OFFSET_FROM_TOP = 22.0 + self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT = 181.0 + + def __call__(self, inference_pipeline, images, results): + self.output_dir = inference_pipeline.output_dir + self.dicom_file_names = inference_pipeline.dicom_file_names + # if spine is an attribute of the inference pipeline, use it + if not hasattr(inference_pipeline, "spine"): + spine = False + else: + spine = True + + for i, (image, result) in enumerate(zip(images, results)): + # now, result is a dict with keys for each tissue + dicom_file_name = self.dicom_file_names[i] + self.save_binary_segmentation_overlay(image, result, 
dicom_file_name, spine) + # pass along for next class in pipeline + return {"results": results} + + def save_binary_segmentation_overlay(self, image, result, dicom_file_name, spine): + file_name = dicom_file_name + ".png" + img_in = image + assert img_in.shape == (512, 512), "Image shape is not 512 x 512" + + img_in = np.clip(img_in, -300, 1800) + img_in = self.normalize_img(img_in) * 255.0 + + # Create the folder to save the images + images_base_path = Path(self.output_dir) / "images" + images_base_path.mkdir(exist_ok=True) + + text_start_vertical_offset = self._MUSCLE_FAT_TEXT_OFFSET_FROM_TOP + + img_in = img_in.reshape((img_in.shape[0], img_in.shape[1], 1)) + img_rgb = np.tile(img_in, (1, 1, 3)) + + vis = Visualizer(img_rgb) + vis.draw_text( + text="Density (HU)", + position=( + img_in.shape[1] - self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT - 63, + text_start_vertical_offset, + ), + color=[1, 1, 1], + font_size=9, + horizontal_alignment="left", + ) + vis.draw_text( + text="Area (CM²)", + position=( + img_in.shape[1] - self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT - 63, + text_start_vertical_offset + self._MUSCLE_FAT_TEXT_VERTICAL_SPACING, + ), + color=[1, 1, 1], + font_size=9, + horizontal_alignment="left", + ) + + if spine: + spine_color = np.array(self._spine_colors[dicom_file_name]) / 255.0 + vis.draw_box( + box_coord=(1, 1, img_in.shape[0] - 1, img_in.shape[1] - 1), + alpha=1, + edge_color=spine_color, + ) + # draw the level T12 - L5 in the upper left corner + if dicom_file_name == "T12": + position = (40, 15) + else: + position = (30, 15) + vis.draw_text( + text=dicom_file_name, position=position, color=spine_color, font_size=24 + ) + + for idx, tissue in enumerate(result.keys()): + alpha_val = 0.9 + color = np.array(self._muscle_fat_colors[tissue]) / 255.0 + edge_color = color + mask = result[tissue]["mask"] + + vis.draw_binary_mask( + mask, + color=color, + edge_color=edge_color, + alpha=alpha_val, + area_threshold=0, + ) + + hu_val = round(result[tissue]["Hounsfield Unit"]) + area_val = round(result[tissue]["Cross-sectional Area (cm^2)"]) + + vis.draw_text( + text=tissue, + position=( + mask.shape[1] + - self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT + + self._MUSCLE_FAT_TEXT_HORIZONTAL_SPACING * (idx + 1), + text_start_vertical_offset - self._MUSCLE_FAT_TEXT_VERTICAL_SPACING, + ), + color=color, + font_size=9, + horizontal_alignment="center", + ) + + vis.draw_text( + text=hu_val, + position=( + mask.shape[1] + - self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT + + self._MUSCLE_FAT_TEXT_HORIZONTAL_SPACING * (idx + 1), + text_start_vertical_offset, + ), + color=color, + font_size=9, + horizontal_alignment="center", + ) + vis.draw_text( + text=area_val, + position=( + mask.shape[1] + - self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT + + self._MUSCLE_FAT_TEXT_HORIZONTAL_SPACING * (idx + 1), + text_start_vertical_offset + self._MUSCLE_FAT_TEXT_VERTICAL_SPACING, + ), + color=color, + font_size=9, + horizontal_alignment="center", + ) + + vis_obj = vis.get_output() + vis_obj.save(os.path.join(images_base_path, file_name)) + + def normalize_img(self, img: np.ndarray) -> np.ndarray: + """Normalize the image. + + Args: + img (np.ndarray): Input image. + + Returns: + np.ndarray: Normalized image. 
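+
+        Example:
+
+            >>> MuscleAdiposeTissueVisualizer().normalize_img(np.array([[0.0, 150.0, 300.0]])).tolist()
+            [[0.0, 0.5, 1.0]]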
+ """ + return (img - img.min()) / (img.max() - img.min()) diff --git a/Comp2Comp-main/comp2comp/spine/spine.py b/Comp2Comp-main/comp2comp/spine/spine.py new file mode 100644 index 0000000000000000000000000000000000000000..42c8fd96754b35b309f7ecde4b2b535aecff31f2 --- /dev/null +++ b/Comp2Comp-main/comp2comp/spine/spine.py @@ -0,0 +1,483 @@ +""" +@author: louisblankemeier +""" + +import math +import os +import shutil +import zipfile +from pathlib import Path +from time import time +from typing import Union + +import nibabel as nib +import numpy as np +import pandas as pd +import wget +from PIL import Image +from totalsegmentatorv2.python_api import totalsegmentator + +from comp2comp.inference_class_base import InferenceClass +from comp2comp.io import io_utils +from comp2comp.models.models import Models +from comp2comp.spine import spine_utils +from comp2comp.visualization.dicom import to_dicom + +# from totalsegmentator.libs import ( +# download_pretrained_weights, +# nostdout, +# setup_nnunet, +# ) + + + + +class SpineSegmentation(InferenceClass): + """Spine segmentation.""" + + def __init__(self, model_name, save=True): + super().__init__() + self.model_name = model_name + self.save_segmentations = save + + def __call__(self, inference_pipeline): + # inference_pipeline.dicom_series_path = self.input_path + self.output_dir = inference_pipeline.output_dir + self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/") + if not os.path.exists(self.output_dir_segmentations): + os.makedirs(self.output_dir_segmentations) + + self.model_dir = inference_pipeline.model_dir + + # seg, mv = self.spine_seg( + # os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), + # self.output_dir_segmentations + "spine.nii.gz", + # inference_pipeline.model_dir, + # ) + os.environ["TOTALSEG_WEIGHTS_PATH"] = self.model_dir + + seg = totalsegmentator( + input=os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), + output=os.path.join(self.output_dir_segmentations, "segmentation.nii"), + task_ids=[292], + ml=True, + nr_thr_resamp=1, + nr_thr_saving=6, + fast=False, + nora_tag="None", + preview=False, + task="total", + # roi_subset=[ + # "vertebrae_T12", + # "vertebrae_L1", + # "vertebrae_L2", + # "vertebrae_L3", + # "vertebrae_L4", + # "vertebrae_L5", + # ], + roi_subset=None, + statistics=False, + radiomics=False, + crop_path=None, + body_seg=False, + force_split=False, + output_type="nifti", + quiet=False, + verbose=False, + test=0, + skip_saving=True, + device="gpu", + license_number=None, + statistics_exclude_masks_at_border=True, + no_derived_masks=False, + v1_order=False, + ) + mv = nib.load( + os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz") + ) + + # inference_pipeline.segmentation = nib.load( + # os.path.join(self.output_dir_segmentations, "segmentation.nii") + # ) + inference_pipeline.segmentation = seg + inference_pipeline.medical_volume = mv + inference_pipeline.save_segmentations = self.save_segmentations + return {} + + def setup_nnunet_c2c(self, model_dir: Union[str, Path]): + """Adapted from TotalSegmentator.""" + + model_dir = Path(model_dir) + config_dir = model_dir / Path("." 
+ self.model_name) + (config_dir / "nnunet/results/nnUNet/3d_fullres").mkdir( + exist_ok=True, parents=True + ) + (config_dir / "nnunet/results/nnUNet/2d").mkdir(exist_ok=True, parents=True) + weights_dir = config_dir / "nnunet/results" + self.weights_dir = weights_dir + + os.environ["nnUNet_raw_data_base"] = str( + weights_dir + ) # not needed, just needs to be an existing directory + os.environ["nnUNet_preprocessed"] = str( + weights_dir + ) # not needed, just needs to be an existing directory + os.environ["RESULTS_FOLDER"] = str(weights_dir) + + def download_spine_model(self, model_dir: Union[str, Path]): + download_dir = Path( + os.path.join( + self.weights_dir, + "nnUNet/3d_fullres/Task252_Spine/nnUNetTrainerV2_ep4000_nomirror__nnUNetPlansv2.1", + ) + ) + fold_0_path = download_dir / "fold_0" + if not os.path.exists(fold_0_path): + download_dir.mkdir(parents=True, exist_ok=True) + wget.download( + "https://huggingface.co/louisblankemeier/spine_v1/resolve/main/fold_0.zip", + out=os.path.join(download_dir, "fold_0.zip"), + ) + with zipfile.ZipFile( + os.path.join(download_dir, "fold_0.zip"), "r" + ) as zip_ref: + zip_ref.extractall(download_dir) + os.remove(os.path.join(download_dir, "fold_0.zip")) + wget.download( + "https://huggingface.co/louisblankemeier/spine_v1/resolve/main/plans.pkl", + out=os.path.join(download_dir, "plans.pkl"), + ) + print("Spine model downloaded.") + else: + print("Spine model already downloaded.") + + def spine_seg( + self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir + ): + """Run spine segmentation. + + Args: + input_path (Union[str, Path]): Input path. + output_path (Union[str, Path]): Output path. + """ + + print("Segmenting spine...") + st = time() + os.environ["SCRATCH"] = self.model_dir + os.environ["TOTALSEG_WEIGHTS_PATH"] = self.model_dir + + # Setup nnunet + model = "3d_fullres" + folds = [0] + trainer = "nnUNetTrainerV2_ep4000_nomirror" + crop_path = None + task_id = [252] + + if self.model_name == "ts_spine": + setup_nnunet() + download_pretrained_weights(task_id[0]) + elif self.model_name == "stanford_spine_v0.0.1": + self.setup_nnunet_c2c(model_dir) + self.download_spine_model(model_dir) + else: + raise ValueError("Invalid model name.") + + if not self.save_segmentations: + output_path = None + + from totalsegmentator.nnunet import nnUNet_predict_image + + with nostdout(): + img, seg = nnUNet_predict_image( + input_path, + output_path, + task_id, + model=model, + folds=folds, + trainer=trainer, + tta=False, + multilabel_image=True, + resample=1.5, + crop=None, + crop_path=crop_path, + task_name="total", + nora_tag="None", + preview=False, + nr_threads_resampling=1, + nr_threads_saving=6, + quiet=False, + verbose=False, + test=0, + ) + end = time() + + # Log total time for spine segmentation + print(f"Total time for spine segmentation: {end-st:.2f}s.") + + if self.model_name == "stanford_spine_v0.0.1": + seg_data = seg.get_fdata() + # subtract 17 from seg values except for 0 + seg_data = np.where(seg_data == 0, 0, seg_data - 17) + seg = nib.Nifti1Image(seg_data, seg.affine, seg.header) + + return seg, img + + +class AxialCropper(InferenceClass): + """Crop the CT image (medical_volume) and segmentation based on user-specified + lower and upper levels of the spine. + """ + + def __init__(self, lower_level: str = "L5", upper_level: str = "L1", save=True): + """ + Args: + lower_level (str, optional): Lower level of the spine. Defaults to "L5". + upper_level (str, optional): Upper level of the spine. Defaults to "L1". 
+ save (bool, optional): Save cropped image and segmentation. Defaults to True. + + Raises: + ValueError: If lower_level or upper_level is not a valid spine level. + """ + super().__init__() + self.lower_level = lower_level + self.upper_level = upper_level + ts_spine_full_model = Models.model_from_name("ts_spine_full") + categories = ts_spine_full_model.categories + try: + self.lower_level_index = categories[self.lower_level] + self.upper_level_index = categories[self.upper_level] + except KeyError: + raise ValueError("Invalid spine level.") from None + self.save = save + + def __call__(self, inference_pipeline): + """ + First dim goes from L to R. + Second dim goes from P to A. + Third dim goes from I to S. + """ + segmentation = inference_pipeline.segmentation + segmentation_data = segmentation.get_fdata() + upper_level_index = np.where(segmentation_data == self.upper_level_index)[ + 2 + ].max() + lower_level_index = np.where(segmentation_data == self.lower_level_index)[ + 2 + ].min() + segmentation = segmentation.slicer[:, :, lower_level_index:upper_level_index] + inference_pipeline.segmentation = segmentation + + medical_volume = inference_pipeline.medical_volume + medical_volume = medical_volume.slicer[ + :, :, lower_level_index:upper_level_index + ] + inference_pipeline.medical_volume = medical_volume + + if self.save: + nib.save( + segmentation, + os.path.join( + inference_pipeline.output_dir, "segmentations", "spine.nii.gz" + ), + ) + nib.save( + medical_volume, + os.path.join( + inference_pipeline.output_dir, + "segmentations", + "converted_dcm.nii.gz", + ), + ) + return {} + + +class SpineComputeROIs(InferenceClass): + def __init__(self, spine_model): + super().__init__() + self.spine_model_name = spine_model + self.spine_model_type = Models.model_from_name(self.spine_model_name) + + def __call__(self, inference_pipeline): + # Compute ROIs + inference_pipeline.spine_model_type = self.spine_model_type + + (spine_hus, rois, segmentation_hus, centroids_3d) = spine_utils.compute_rois( + inference_pipeline.segmentation, + inference_pipeline.medical_volume, + self.spine_model_type, + ) + + inference_pipeline.spine_hus = spine_hus + inference_pipeline.segmentation_hus = segmentation_hus + inference_pipeline.rois = rois + inference_pipeline.centroids_3d = centroids_3d + + return {} + + +class SpineMetricsSaver(InferenceClass): + """Save metrics to a CSV file.""" + + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline): + """Save metrics to a CSV file.""" + self.spine_hus = inference_pipeline.spine_hus + self.seg_hus = inference_pipeline.segmentation_hus + self.output_dir = inference_pipeline.output_dir + self.csv_output_dir = os.path.join(self.output_dir, "metrics") + if not os.path.exists(self.csv_output_dir): + os.makedirs(self.csv_output_dir, exist_ok=True) + self.save_results() + if hasattr(inference_pipeline, "dicom_ds"): + if not os.path.exists(os.path.join(self.output_dir, "dicom_metadata.csv")): + io_utils.write_dicom_metadata_to_csv( + inference_pipeline.dicom_ds, + os.path.join(self.output_dir, "dicom_metadata.csv"), + ) + + return {} + + def save_results(self): + """Save results to a CSV file.""" + df = pd.DataFrame(columns=["Level", "ROI HU", "Seg HU"]) + for i, level in enumerate(self.spine_hus): + hu = self.spine_hus[level] + seg_hu = self.seg_hus[level] + row = [level, hu, seg_hu] + df.loc[i] = row + df = df.iloc[::-1] + df.to_csv(os.path.join(self.csv_output_dir, "spine_metrics.csv"), index=False) + + +class SpineFindDicoms(InferenceClass): 
+ def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline): + inferior_superior_centers = spine_utils.find_spine_dicoms( + inference_pipeline.centroids_3d, + ) + + spine_utils.save_nifti_select_slices( + inference_pipeline.output_dir, inferior_superior_centers + ) + inference_pipeline.dicom_file_paths = [ + str(center) for center in inferior_superior_centers + ] + inference_pipeline.names = list(inference_pipeline.rois.keys()) + inference_pipeline.dicom_file_names = list(inference_pipeline.rois.keys()) + inference_pipeline.inferior_superior_centers = inferior_superior_centers + + return {} + + +class SpineCoronalSagittalVisualizer(InferenceClass): + def __init__(self, format="png"): + super().__init__() + self.format = format + + def __call__(self, inference_pipeline): + output_path = inference_pipeline.output_dir + spine_model_type = inference_pipeline.spine_model_type + + img_sagittal, img_coronal = spine_utils.visualize_coronal_sagittal_spine( + inference_pipeline.segmentation.get_fdata(), + list(inference_pipeline.rois.values()), + inference_pipeline.medical_volume.get_fdata(), + list(inference_pipeline.centroids_3d.values()), + output_path, + spine_hus=inference_pipeline.spine_hus, + seg_hus=inference_pipeline.segmentation_hus, + model_type=spine_model_type, + pixel_spacing=inference_pipeline.pixel_spacing_list, + format=self.format, + ) + inference_pipeline.spine_vis_sagittal = img_sagittal + inference_pipeline.spine_vis_coronal = img_coronal + inference_pipeline.spine = True + if not inference_pipeline.save_segmentations: + shutil.rmtree(os.path.join(output_path, "segmentations")) + return {} + + +class SpineReport(InferenceClass): + def __init__(self, format="png"): + super().__init__() + self.format = format + + def __call__(self, inference_pipeline): + sagittal_image = inference_pipeline.spine_vis_sagittal + coronal_image = inference_pipeline.spine_vis_coronal + # concatenate these numpy arrays laterally + img = np.concatenate((coronal_image, sagittal_image), axis=1) + output_path = os.path.join( + inference_pipeline.output_dir, "images", "spine_report" + ) + if self.format == "png": + im = Image.fromarray(img) + im.save(output_path + ".png") + elif self.format == "dcm": + to_dicom(img, output_path + ".dcm") + return {} + + +class SpineMuscleAdiposeTissueReport(InferenceClass): + """Spine muscle adipose tissue report class.""" + + def __init__(self): + super().__init__() + self.image_files = [ + "spine_coronal.png", + "spine_sagittal.png", + "T12.png", + "L1.png", + "L2.png", + "L3.png", + "L4.png", + "L5.png", + ] + + def __call__(self, inference_pipeline): + image_dir = Path(inference_pipeline.output_dir) / "images" + self.generate_panel(image_dir) + return {} + + def generate_panel(self, image_dir: Union[str, Path]): + """Generate panel. + Args: + image_dir (Union[str, Path]): Path to the image directory. 
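+
+        Example (sketch; assumes the coronal/sagittal and per-level PNGs
+        listed in ``self.image_files`` already exist under ``image_dir``):
+
+            >>> SpineMuscleAdiposeTissueReport().generate_panel("outputs/images")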
+ """ + image_files = [os.path.join(image_dir, path) for path in self.image_files] + # construct a list which includes only the images that exist + image_files = [path for path in image_files if os.path.exists(path)] + + im_cor = Image.open(image_files[0]) + im_sag = Image.open(image_files[1]) + im_cor_width = int(im_cor.width / im_cor.height * 512) + num_muscle_fat_cols = math.ceil((len(image_files) - 2) / 2) + width = (8 + im_cor_width + 8) + ((512 + 8) * num_muscle_fat_cols) + height = 1048 + new_im = Image.new("RGB", (width, height)) + + index = 2 + for j in range(8, height, 520): + for i in range(8 + im_cor_width + 8, width, 520): + try: + im = Image.open(image_files[index]) + im.thumbnail((512, 512)) + new_im.paste(im, (i, j)) + index += 1 + im.close() + except Exception: + continue + + im_cor.thumbnail((im_cor_width, 512)) + new_im.paste(im_cor, (8, 8)) + im_sag.thumbnail((im_cor_width, 512)) + new_im.paste(im_sag, (8, 528)) + new_im.save(os.path.join(image_dir, "spine_muscle_adipose_tissue_report.png")) + im_cor.close() + im_sag.close() + new_im.close() diff --git a/Comp2Comp-main/comp2comp/spine/spine_utils.py b/Comp2Comp-main/comp2comp/spine/spine_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2fb9f1a05dbb693cdcdc7453f4172d751f7086c0 --- /dev/null +++ b/Comp2Comp-main/comp2comp/spine/spine_utils.py @@ -0,0 +1,737 @@ +""" +@author: louisblankemeier +""" + +import logging +import math +import os +from typing import Dict, List + +import cv2 +import nibabel as nib +import numpy as np +from scipy.ndimage import zoom + +from comp2comp.spine import spine_visualization + + +def find_spine_dicoms(centroids: Dict): # , path: str, levels): + """Find the dicom files corresponding to the spine T12 - L5 levels.""" + + vertical_positions = [] + for level in centroids: + centroid = centroids[level] + vertical_positions.append(round(centroid[2])) + + # dicom_files = [] + # ipps = [] + # for dicom_path in glob(path + "/*.dcm"): + # ipp = dcmread(dicom_path).ImagePositionPatient + # ipps.append(ipp[2]) + # dicom_files.append(dicom_path) + + # dicom_files = [x for _, x in sorted(zip(ipps, dicom_files))] + # dicom_files = list(np.array(dicom_files)[vertical_positions]) + + # return (dicom_files, levels, vertical_positions) + return vertical_positions + + +def save_nifti_select_slices(output_dir: str, vertical_positions): + nifti_path = os.path.join(output_dir, "segmentations", "converted_dcm.nii.gz") + nifti_in = nib.load(nifti_path) + nifti_np = nifti_in.get_fdata() + nifti_np = nifti_np[:, :, vertical_positions] + nifti_out = nib.Nifti1Image(nifti_np, nifti_in.affine, nifti_in.header) + # save the nifti + nifti_output_path = os.path.join( + output_dir, "segmentations", "converted_dcm.nii.gz" + ) + nib.save(nifti_out, nifti_output_path) + + +# Function that takes a numpy array as input, computes the +# sagittal centroid of each label and returns a list of the +# centroids +def compute_centroids(seg: np.ndarray, spine_model_type): + """Compute the centroids of the labels. + + Args: + seg (np.ndarray): Segmentation volume. + spine_model_type (str): Model type. + + Returns: + List[int]: List of centroids. 
+ """ + # take values of spine_model_type.categories dictionary + # and convert to list + centroids = {} + for level in spine_model_type.categories: + label_idx = spine_model_type.categories[level] + try: + pos = compute_centroid(seg, "sagittal", label_idx) + centroids[level] = pos + except Exception: + logging.warning(f"Label {level} not found in segmentation volume.") + return centroids + + +# Function that takes a numpy array as input, as well as a list of centroids, +# takes a slice through the centroid on axis = 1 for each centroid +# and returns a list of the slices +def get_slices(seg: np.ndarray, centroids: Dict, spine_model_type): + """Get the slices corresponding to the centroids. + + Args: + seg (np.ndarray): Segmentation volume. + centroids (List[int]): List of centroids. + spine_model_type (str): Model type. + + Returns: + List[np.ndarray]: List of slices. + """ + seg = seg.astype(np.uint8) + slices = {} + for level in centroids: + label_idx = spine_model_type.categories[level] + binary_seg = (seg[centroids[level], :, :] == label_idx).astype(int) + if ( + np.sum(binary_seg) > 200 + ): # heuristic to make sure enough of the body is showing + slices[level] = binary_seg + return slices + + +# Function that takes a mask and for each deletes the right most +# connected component. Returns the mask with the right most +# connected component deleted +def delete_right_most_connected_component(mask: np.ndarray): + """Delete the right most connected component corresponding to spinous processes. + + Args: + mask (np.ndarray): Mask volume. + + Returns: + np.ndarray: Mask volume. + """ + mask = mask.astype(np.uint8) + _, labels, _, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8) + right_most_connected_component = np.argmin(centroids[1:, 1]) + 1 + mask[labels == right_most_connected_component] = 0 + return mask + + +# compute center of mass of 2d mask +def compute_center_of_mass(mask: np.ndarray): + """Compute the center of mass of a 2D mask. + + Args: + mask (np.ndarray): Mask volume. + + Returns: + np.ndarray: Center of mass. + """ + mask = mask.astype(np.uint8) + _, _, _, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8) + center_of_mass = np.mean(centroids[1:, :], axis=0) + return center_of_mass + + +# Function that takes a 3d centroid and retruns a binary mask with a 3d +# roi around the centroid +def roi_from_mask(img, centroid: np.ndarray): + """Compute a 3D ROI from a 3D mask. + + Args: + img (np.ndarray): Image volume. + centroid (np.ndarray): Centroid. + + Returns: + np.ndarray: ROI volume. + """ + roi = np.zeros(img.shape) + + img_np = img.get_fdata() + + pixel_spacing = img.header.get_zooms() + length_i = 5.0 / pixel_spacing[0] + length_j = 5.0 / pixel_spacing[1] + length_k = 5.0 / pixel_spacing[2] + + print( + f"Computing ROI with centroid {centroid[0]:.3f}, {centroid[1]:.3f}, {centroid[2]:.3f} " + f"and pixel spacing " + f"{pixel_spacing[0]:.3f}mm, {pixel_spacing[1]:.3f}mm, {pixel_spacing[2]:.3f}mm..." 
+    )
+
+    # cubic ROI around centroid
+    """
+    roi[
+        int(centroid[0] - length) : int(centroid[0] + length),
+        int(centroid[1] - length) : int(centroid[1] + length),
+        int(centroid[2] - length) : int(centroid[2] + length),
+    ] = 1
+    """
+    # spherical ROI around centroid
+    roi = np.zeros(img_np.shape)
+    i_lower = math.floor(centroid[0] - length_i)
+    j_lower = math.floor(centroid[1] - length_j)
+    k_lower = math.floor(centroid[2] - length_k)
+    i_lower_idx = 1000
+    j_lower_idx = 1000
+    k_lower_idx = 1000
+    i_upper_idx = 0
+    j_upper_idx = 0
+    k_upper_idx = 0
+    found_pixels = False
+    for i in range(i_lower, i_lower + 2 * math.ceil(length_i) + 1):
+        for j in range(j_lower, j_lower + 2 * math.ceil(length_j) + 1):
+            for k in range(k_lower, k_lower + 2 * math.ceil(length_k) + 1):
+                if (i - centroid[0]) ** 2 / length_i**2 + (
+                    j - centroid[1]
+                ) ** 2 / length_j**2 + (k - centroid[2]) ** 2 / length_k**2 <= 1:
+                    roi[i, j, k] = 1
+                    if i < i_lower_idx:
+                        i_lower_idx = i
+                    if j < j_lower_idx:
+                        j_lower_idx = j
+                    if k < k_lower_idx:
+                        k_lower_idx = k
+                    if i > i_upper_idx:
+                        i_upper_idx = i
+                    if j > j_upper_idx:
+                        j_upper_idx = j
+                    if k > k_upper_idx:
+                        k_upper_idx = k
+                    found_pixels = True
+    if not found_pixels:
+        raise ValueError("No pixels in ROI!")
+    print(
+        f"Number of pixels included in i, j, and k directions: {i_upper_idx - i_lower_idx + 1}, "
+        f"{j_upper_idx - j_lower_idx + 1}, {k_upper_idx - k_lower_idx + 1}"
+    )
+    return roi
+
+
+# Function that takes a 3d image and a 3d binary mask and returns the average
+# value of the image inside the mask
+def mean_img_mask(img: np.ndarray, mask: np.ndarray, index: int):
+    """Compute the mean of an image inside a mask.
+
+    Args:
+        img (np.ndarray): Image volume.
+        mask (np.ndarray): Mask volume.
+        index (int): Level index (currently unused).
+
+    Returns:
+        float: Mean value.
+    """
+    img = img.astype(np.float32)
+    mask = mask.astype(np.float32)
+    img_masked = (img * mask)[mask > 0]
+    # mean = (rescale_slope * np.mean(img_masked)) + rescale_intercept
+    # median = (rescale_slope * np.median(img_masked)) + rescale_intercept
+    mean = np.mean(img_masked)
+    return mean
+
+
+def compute_rois(seg, img, spine_model_type):
+    """Compute the ROIs for the spine.
+
+    Args:
+        seg (nib.Nifti1Image): Segmentation volume.
+        img (nib.Nifti1Image): Image volume.
+        spine_model_type (Models): Model type.
+
+    Returns:
+        spine_hus (Dict[str, float]): ROI HU values per level.
+        rois (Dict[str, np.ndarray]): ROI masks per level.
+        segmentation_hus (Dict[str, float]): Segmentation HU values per level.
+        centroids_3d (Dict[str, np.ndarray]): 3D centroids per level.
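+
+    Example (hypothetical NIfTI inputs):
+        >>> spine_hus, rois, seg_hus, centroids_3d = compute_rois(
+        ...     seg_nifti, img_nifti, spine_model
+        ... )  # doctest: +SKIP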
+ """ + seg_np = seg.get_fdata() + centroids = compute_centroids(seg_np, spine_model_type) + slices = get_slices(seg_np, centroids, spine_model_type) + for level in slices: + slice = slices[level] + # keep only the two largest connected components + two_largest, two = keep_two_largest_connected_components(slice) + if two: + slices[level] = delete_right_most_connected_component(two_largest) + + # Compute ROIs + rois = {} + spine_hus = {} + centroids_3d = {} + segmentation_hus = {} + for i, level in enumerate(slices): + slice = slices[level] + center_of_mass = compute_center_of_mass(slice) + centroid = np.array([centroids[level], center_of_mass[1], center_of_mass[0]]) + roi = roi_from_mask(img, centroid) + image_numpy = img.get_fdata() + spine_hus[level] = mean_img_mask(image_numpy, roi, i) + rois[level] = roi + mask = (seg_np == spine_model_type.categories[level]).astype(int) + segmentation_hus[level] = mean_img_mask(image_numpy, mask, i) + centroids_3d[level] = centroid + return (spine_hus, rois, segmentation_hus, centroids_3d) + + +def keep_two_largest_connected_components(mask: Dict): + """Keep the two largest connected components. + + Args: + mask (np.ndarray): Mask volume. + + Returns: + np.ndarray: Mask volume. + """ + mask = mask.astype(np.uint8) + # sort connected components by size + num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats( + mask, connectivity=8 + ) + stats = stats[1:, 4] + sorted_indices = np.argsort(stats)[::-1] + # keep only the two largest connected components + mask = np.zeros(mask.shape) + mask[labels == sorted_indices[0] + 1] = 1 + two = True + try: + mask[labels == sorted_indices[1] + 1] = 1 + except Exception: + two = False + return (mask, two) + + +def compute_centroid(seg: np.ndarray, plane: str, label: int): + """Compute the centroid of a label in a given plane. + + Args: + seg (np.ndarray): Segmentation volume. + plane (str): Plane. + label (int): Label. + + Returns: + int: Centroid. + """ + if plane == "axial": + sum_out_axes = (0, 1) + sum_axis = 2 + elif plane == "sagittal": + sum_out_axes = (1, 2) + sum_axis = 0 + elif plane == "coronal": + sum_out_axes = (0, 2) + sum_axis = 1 + sums = np.sum(seg == label, axis=sum_out_axes) + normalized_sums = sums / np.sum(sums) + pos = int(np.sum(np.arange(0, seg.shape[sum_axis]) * normalized_sums)) + return pos + + +def to_one_hot(label: np.ndarray, model_type, spine_hus): + """Convert a label to one-hot encoding. + + Args: + label (np.ndarray): Label volume. + model_type (Models): Model type. + + Returns: + np.ndarray: One-hot encoding volume. + """ + levels = list(spine_hus.keys()) + levels.reverse() + one_hot_label = np.zeros((label.shape[0], label.shape[1], len(levels))) + for i, level in enumerate(levels): + label_idx = model_type.categories[level] + one_hot_label[:, :, i] = (label == label_idx).astype(int) + return one_hot_label + + +def visualize_coronal_sagittal_spine( + seg: np.ndarray, + rois: List[np.ndarray], + mvs: np.ndarray, + centroids_3d: np.ndarray, + output_dir: str, + spine_hus=None, + seg_hus=None, + model_type=None, + pixel_spacing=None, + format="png", +): + """Visualize the coronal and sagittal planes of the spine. + + Args: + seg (np.ndarray): Segmentation volume. + rois (List[np.ndarray]): List of ROIs. + mvs (dm.MedicalVolume): Medical volume. + centroids (List[int]): List of centroids. + label_text (List[str]): List of labels. + output_dir (str): Output directory. + spine_hus (List[float], optional): List of HU values. Defaults to None. 
+        model_type (Models, optional): Model type. Defaults to None.
+    """
+
+    sagittal_vals, coronal_vals = curved_planar_reformation(mvs, centroids_3d)
+    zoom_factor = pixel_spacing[2] / pixel_spacing[1]
+
+    sagittal_image = mvs[sagittal_vals, :, range(len(sagittal_vals))]
+    sagittal_label = seg[sagittal_vals, :, range(len(sagittal_vals))]
+    sagittal_image = zoom(sagittal_image, (zoom_factor, 1), order=3)
+    sagittal_label = zoom(sagittal_label, (zoom_factor, 1), order=1).round()
+
+    one_hot_sag_label = to_one_hot(sagittal_label, model_type, spine_hus)
+    for roi in rois:
+        one_hot_roi_label = roi[sagittal_vals, :, range(len(sagittal_vals))]
+        one_hot_roi_label = zoom(one_hot_roi_label, (zoom_factor, 1), order=1).round()
+        one_hot_sag_label = np.concatenate(
+            (
+                one_hot_sag_label,
+                one_hot_roi_label.reshape(
+                    (one_hot_roi_label.shape[0], one_hot_roi_label.shape[1], 1)
+                ),
+            ),
+            axis=2,
+        )
+
+    coronal_image = mvs[:, coronal_vals, range(len(coronal_vals))]
+    coronal_label = seg[:, coronal_vals, range(len(coronal_vals))]
+    coronal_image = zoom(coronal_image, (1, zoom_factor), order=3)
+    coronal_label = zoom(coronal_label, (1, zoom_factor), order=1).round()
+
+    one_hot_cor_label = to_one_hot(coronal_label, model_type, spine_hus)
+    for roi in rois:
+        one_hot_roi_label = roi[:, coronal_vals, range(len(coronal_vals))]
+        one_hot_roi_label = zoom(one_hot_roi_label, (1, zoom_factor), order=1).round()
+        one_hot_cor_label = np.concatenate(
+            (
+                one_hot_cor_label,
+                one_hot_roi_label.reshape(
+                    (one_hot_roi_label.shape[0], one_hot_roi_label.shape[1], 1)
+                ),
+            ),
+            axis=2,
+        )
+
+    # flip both axes of sagittal image
+    sagittal_image = np.flip(sagittal_image, axis=0)
+    sagittal_image = np.flip(sagittal_image, axis=1)
+
+    # flip both axes of sagittal label
+    one_hot_sag_label = np.flip(one_hot_sag_label, axis=0)
+    one_hot_sag_label = np.flip(one_hot_sag_label, axis=1)
+
+    coronal_image = np.transpose(coronal_image)
+    one_hot_cor_label = np.transpose(one_hot_cor_label, (1, 0, 2))
+
+    # flip both axes of coronal image
+    coronal_image = np.flip(coronal_image, axis=0)
+    coronal_image = np.flip(coronal_image, axis=1)
+
+    # flip both axes of coronal label
+    one_hot_cor_label = np.flip(one_hot_cor_label, axis=0)
+    one_hot_cor_label = np.flip(one_hot_cor_label, axis=1)
+
+    if format == "png":
+        sagittal_name = "spine_sagittal.png"
+        coronal_name = "spine_coronal.png"
+    elif format == "dcm":
+        sagittal_name = "spine_sagittal.dcm"
+        coronal_name = "spine_coronal.dcm"
+    else:
+        raise ValueError("Format must be either png or dcm")
+
+    img_sagittal = spine_visualization.spine_binary_segmentation_overlay(
+        sagittal_image,
+        one_hot_sag_label,
+        output_dir,
+        sagittal_name,
+        spine_hus=spine_hus,
+        seg_hus=seg_hus,
+        model_type=model_type,
+        pixel_spacing=pixel_spacing,
+    )
+    img_coronal = spine_visualization.spine_binary_segmentation_overlay(
+        coronal_image,
+        one_hot_cor_label,
+        output_dir,
+        coronal_name,
+        spine_hus=spine_hus,
+        seg_hus=seg_hus,
+        model_type=model_type,
+        pixel_spacing=pixel_spacing,
+    )
+
+    return img_sagittal, img_coronal
+
+
+def curved_planar_reformation(mvs, centroids):
+    centroids = sorted(centroids, key=lambda x: x[2])
+    centroids = [(int(x[0]), int(x[1]), int(x[2])) for x in centroids]
+    sagittal_centroids = [centroids[i][0] for i in range(0, len(centroids))]
+    coronal_centroids = [centroids[i][1] for i in range(0, len(centroids))]
+    
axial_centroids = [centroids[i][2] for i in range(0, len(centroids))] + sagittal_vals = [sagittal_centroids[0]] * axial_centroids[0] + coronal_vals = [coronal_centroids[0]] * axial_centroids[0] + + for i in range(1, len(axial_centroids)): + num = axial_centroids[i] - axial_centroids[i - 1] + interp = list( + np.linspace(sagittal_centroids[i - 1], sagittal_centroids[i], num=num) + ) + sagittal_vals.extend(interp) + interp = list( + np.linspace(coronal_centroids[i - 1], coronal_centroids[i], num=num) + ) + coronal_vals.extend(interp) + + sagittal_vals.extend([sagittal_centroids[-1]] * (mvs.shape[2] - len(sagittal_vals))) + coronal_vals.extend([coronal_centroids[-1]] * (mvs.shape[2] - len(coronal_vals))) + sagittal_vals = np.array(sagittal_vals) + coronal_vals = np.array(coronal_vals) + sagittal_vals = sagittal_vals.astype(int) + coronal_vals = coronal_vals.astype(int) + + return (sagittal_vals, coronal_vals) + + +''' +def compare_ts_stanford_centroids(labels_path, pred_centroids): + """Compare the centroids of the Stanford dataset with the centroids of the TS dataset. + + Args: + labels_path (str): Path to the Stanford dataset labels. + """ + t12_diff = [] + l1_diff = [] + l2_diff = [] + l3_diff = [] + l4_diff = [] + l5_diff = [] + num_skipped = 0 + + labels = glob(labels_path + "/*") + for label_path in labels: + # modify label_path to give pred_path + pred_path = label_path.replace("labelsTs", "predTs_TS") + print(label_path.split("/")[-1]) + label_nib = nib.load(label_path) + label = label_nib.get_fdata() + spacing = label_nib.header.get_zooms()[2] + pred_nib = nib.load(pred_path) + pred = pred_nib.get_fdata() + if True: + pred[pred == 18] = 6 + pred[pred == 19] = 5 + pred[pred == 20] = 4 + pred[pred == 21] = 3 + pred[pred == 22] = 2 + pred[pred == 23] = 1 + + for label_idx in range(1, 7): + label_level = label == label_idx + indexes = np.array(range(label.shape[2])) + sums = np.sum(label_level, axis=(0, 1)) + normalized_sums = sums / np.sum(sums) + label_centroid = np.sum(indexes * normalized_sums) + print(f"Centroid for label {label_idx}: {label_centroid}") + + if False: + try: + pred_centroid = pred_centroids[6 - label_idx] + except Exception: + # Change this part + print("Something wrong with pred_centroids, skipping!") + num_skipped += 1 + break + + # if revert_to_original: + if True: + pred_level = pred == label_idx + sums = np.sum(pred_level, axis=(0, 1)) + indices = list(range(sums.shape[0])) + groupby_input = zip(indices, list(sums)) + g = groupby(groupby_input, key=lambda x: x[1] > 0.0) + m = max([list(s) for v, s in g if v > 0], key=lambda x: np.sum(list(zip(*x))[1])) + res = list(zip(*m)) + indexes = list(res[0]) + sums = list(res[1]) + normalized_sums = sums / np.sum(sums) + pred_centroid = np.sum(indexes * normalized_sums) + print(f"Centroid for prediction {label_idx}: {pred_centroid}") + + diff = np.absolute(pred_centroid - label_centroid) * spacing + + if label_idx == 1: + t12_diff.append(diff) + elif label_idx == 2: + l1_diff.append(diff) + elif label_idx == 3: + l2_diff.append(diff) + elif label_idx == 4: + l3_diff.append(diff) + elif label_idx == 5: + l4_diff.append(diff) + elif label_idx == 6: + l5_diff.append(diff) + + print(f"Skipped {num_skipped}") + print("The final mean differences in mm:") + print( + np.mean(t12_diff), + np.mean(l1_diff), + np.mean(l2_diff), + np.mean(l3_diff), + np.mean(l4_diff), + np.mean(l5_diff), + ) + print("The final median differences in mm:") + print( + np.median(t12_diff), + np.median(l1_diff), + np.median(l2_diff), + np.median(l3_diff), 
+ np.median(l4_diff), + np.median(l5_diff), + ) + + +def compare_ts_stanford_roi_hus(image_path): + """Compare the HU values of the Stanford dataset with the HU values of the TS dataset. + + image_path (str): Path to the Stanford dataset images. + """ + img_paths = glob(image_path + "/*") + differences = np.zeros((40, 6)) + ground_truth = np.zeros((40, 6)) + for i, img_path in enumerate(img_paths): + print(f"Image number {i + 1}") + image_path_no_0000 = re.sub(r"_0000", "", img_path) + ts_seg_path = image_path_no_0000.replace("imagesTs", "predTs_TS") + stanford_seg_path = image_path_no_0000.replace("imagesTs", "labelsTs") + img = nib.load(img_path).get_fdata() + img = np.swapaxes(img, 0, 1) + ts_seg = nib.load(ts_seg_path).get_fdata() + ts_seg = np.swapaxes(ts_seg, 0, 1) + stanford_seg = nib.load(stanford_seg_path).get_fdata() + stanford_seg = np.swapaxes(stanford_seg, 0, 1) + ts_model_type = Models.model_from_name("ts_spine") + (spine_hus_ts, rois, centroids_3d) = compute_rois(ts_seg, img, 1, 0, ts_model_type) + stanford_model_type = Models.model_from_name("stanford_spine_v0.0.1") + (spine_hus_stanford, rois, centroids_3d) = compute_rois( + stanford_seg, img, 1, 0, stanford_model_type + ) + difference_vals = np.abs(np.array(spine_hus_ts) - np.array(spine_hus_stanford)) + print(f"Differences {difference_vals}\n") + differences[i, :] = difference_vals + ground_truth[i, :] = spine_hus_stanford + print("\n") + # compute average percent change from ground truth + percent_change = np.divide(differences, ground_truth) * 100 + average_percent_change = np.mean(percent_change, axis=0) + median_percent_change = np.median(percent_change, axis=0) + # print average percent change + print("Average percent change from ground truth:") + print(average_percent_change) + print("Median percent change from ground truth:") + print(median_percent_change) + # print average difference + average_difference = np.mean(differences, axis=0) + median_difference = np.median(differences, axis=0) + print("Average difference from ground truth:") + print(average_difference) + print("Median difference from ground truth:") + print(median_difference) + + +def process_post_hoc(pred_path): + """Apply post-hoc heuristics for improving Stanford spine model vertical centroid predictions. + + Args: + pred_path (str): Path to the prediction. 
+ """ + pred_nib = nib.load(pred_path) + pred = pred_nib.get_fdata() + + pred_bodies = np.logical_and(pred >= 1, pred <= 6) + pred_bodies = pred_bodies.astype(np.int64) + + labels_out, N = cc3d.connected_components(pred_bodies, return_N=True, connectivity=6) + + stats = cc3d.statistics(labels_out) + print(stats) + + labels_out_list = [] + voxel_counts_list = list(stats["voxel_counts"]) + for idx_lab in range(1, N + 2): + labels_out_list.append(labels_out == idx_lab) + + centroids_list = list(stats["centroids"][:, 2]) + + labels = [] + centroids = [] + voxels = [] + + for idx, count in enumerate(voxel_counts_list): + if count > 10000: + labels.append(labels_out_list[idx]) + centroids.append(centroids_list[idx]) + voxels.append(count) + + top_comps = [ + (counts0, labels0, centroids0) + for counts0, labels0, centroids0 in sorted(zip(voxels, labels, centroids), reverse=True) + ] + top_comps = top_comps[1:7] + + # ====== Check whether the connected components are fusing vertebral bodies ====== + revert_to_original = False + + volumes = list(zip(*top_comps))[0] + if volumes[0] > 1.5 * volumes[1]: + revert_to_original = True + print("Reverting to original...") + + labels = list(zip(*top_comps))[1] + centroids = list(zip(*top_comps))[2] + + top_comps = zip(centroids, labels) + pred_centroids = [x for x, _ in sorted(top_comps)] + + for label_idx in range(1, 7): + if not revert_to_original: + try: + pred_centroid = pred_centroids[6 - label_idx] + except: + # Change this part + print( + "Post processing failure, probably < 6 predicted bodies. Reverting to original labels." + ) + revert_to_original = True + + if revert_to_original: + pred_level = pred == label_idx + sums = np.sum(pred_level, axis=(0, 1)) + indices = list(range(sums.shape[0])) + groupby_input = zip(indices, list(sums)) + # sys.exit() + g = groupby(groupby_input, key=lambda x: x[1] > 0.0) + m = max([list(s) for v, s in g if v > 0], key=lambda x: np.sum(list(zip(*x))[1])) + # sys.exit() + # m = max([list(s) for v, s in g], key=lambda np.sum) + res = list(zip(*m)) + indexes = list(res[0]) + sums = list(res[1]) + normalized_sums = sums / np.sum(sums) + pred_centroid = np.sum(indexes * normalized_sums) + print(f"Centroid for prediction {label_idx}: {pred_centroid}") +''' diff --git a/Comp2Comp-main/comp2comp/spine/spine_visualization.py b/Comp2Comp-main/comp2comp/spine/spine_visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..de2be7a2447e84e7dda14108f942fc9bf4219b36 --- /dev/null +++ b/Comp2Comp-main/comp2comp/spine/spine_visualization.py @@ -0,0 +1,198 @@ +""" +@author: louisblankemeier +""" + +import os +from pathlib import Path +from typing import Union + +import numpy as np + +from comp2comp.visualization.detectron_visualizer import Visualizer + + +def spine_binary_segmentation_overlay( + img_in: Union[str, Path], + mask: Union[str, Path], + base_path: Union[str, Path], + file_name: str, + figure_text_key=None, + spine_hus=None, + seg_hus=None, + spine=True, + model_type=None, + pixel_spacing=None, +): + """Save binary segmentation overlay. + Args: + img_in (Union[str, Path]): Path to the input image. + mask (Union[str, Path]): Path to the mask. + base_path (Union[str, Path]): Path to the output directory. + file_name (str): Output file name. + centroids (list, optional): List of centroids. Defaults to None. + figure_text_key (dict, optional): Figure text key. Defaults to None. + spine_hus (list, optional): List of HU values. Defaults to None. + spine (bool, optional): Spine flag. Defaults to True. 
+ model_type (Models): Model type. Defaults to None. + """ + _COLORS = ( + np.array( + [ + 1.000, + 0.000, + 0.000, + 0.000, + 1.000, + 0.000, + 1.000, + 1.000, + 0.000, + 1.000, + 0.500, + 0.000, + 0.000, + 1.000, + 1.000, + 1.000, + 0.000, + 1.000, + ] + ) + .astype(np.float32) + .reshape(-1, 3) + ) + + label_map = {"L5": 0, "L4": 1, "L3": 2, "L2": 3, "L1": 4, "T12": 5} + + _ROI_COLOR = np.array([1.000, 0.340, 0.200]) + + _SPINE_TEXT_OFFSET_FROM_TOP = 10.0 + _SPINE_TEXT_OFFSET_FROM_RIGHT = 40.0 + _SPINE_TEXT_VERTICAL_SPACING = 14.0 + + img_in = np.clip(img_in, -300, 1800) + img_in = normalize_img(img_in) * 255.0 + images_base_path = Path(base_path) / "images" + images_base_path.mkdir(exist_ok=True) + + img_in = img_in.reshape((img_in.shape[0], img_in.shape[1], 1)) + img_rgb = np.tile(img_in, (1, 1, 3)) + + vis = Visualizer(img_rgb) + + levels = list(spine_hus.keys()) + levels.reverse() + num_levels = len(levels) + + # draw seg masks + for i, level in enumerate(levels): + color = _COLORS[label_map[level]] + edge_color = None + alpha_val = 0.2 + vis.draw_binary_mask( + mask[:, :, i].astype(int), + color=color, + edge_color=edge_color, + alpha=alpha_val, + area_threshold=0, + ) + + # draw rois + for i, _ in enumerate(levels): + color = _ROI_COLOR + edge_color = color + vis.draw_binary_mask( + mask[:, :, num_levels + i].astype(int), + color=color, + edge_color=edge_color, + alpha=alpha_val, + area_threshold=0, + ) + + vis.draw_text( + text="ROI", + position=( + mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT - 35, + _SPINE_TEXT_OFFSET_FROM_TOP, + ), + color=[1, 1, 1], + font_size=9, + horizontal_alignment="center", + ) + + vis.draw_text( + text="Seg", + position=( + mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT, + _SPINE_TEXT_OFFSET_FROM_TOP, + ), + color=[1, 1, 1], + font_size=9, + horizontal_alignment="center", + ) + + # draw text and lines + for i, level in enumerate(levels): + vis.draw_text( + text=f"{level}:", + position=( + mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT - 80, + _SPINE_TEXT_VERTICAL_SPACING * (i + 1) + _SPINE_TEXT_OFFSET_FROM_TOP, + ), + color=_COLORS[label_map[level]], + font_size=9, + horizontal_alignment="left", + ) + vis.draw_text( + text=f"{round(float(spine_hus[level]))}", + position=( + mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT - 35, + _SPINE_TEXT_VERTICAL_SPACING * (i + 1) + _SPINE_TEXT_OFFSET_FROM_TOP, + ), + color=_COLORS[label_map[level]], + font_size=9, + horizontal_alignment="center", + ) + vis.draw_text( + text=f"{round(float(seg_hus[level]))}", + position=( + mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT, + _SPINE_TEXT_VERTICAL_SPACING * (i + 1) + _SPINE_TEXT_OFFSET_FROM_TOP, + ), + color=_COLORS[label_map[level]], + font_size=9, + horizontal_alignment="center", + ) + + """ + vis.draw_line( + x_data=(0, mask.shape[1] - 1), + y_data=( + int( + inferior_superior_centers[num_levels - i - 1] + * (pixel_spacing[2] / pixel_spacing[1]) + ), + int( + inferior_superior_centers[num_levels - i - 1] + * (pixel_spacing[2] / pixel_spacing[1]) + ), + ), + color=_COLORS[label_map[level]], + linestyle="dashed", + linewidth=0.25, + ) + """ + + vis_obj = vis.get_output() + img = vis_obj.save(os.path.join(images_base_path, file_name)) + return img + + +def normalize_img(img: np.ndarray) -> np.ndarray: + """Normalize the image. + Args: + img (np.ndarray): Input image. + Returns: + np.ndarray: Normalized image. 
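+
+    Example (min-max scaling to [0, 1]):
+        >>> normalize_img(np.array([0.0, 50.0, 100.0])).tolist()
+        [0.0, 0.5, 1.0]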
+ """ + return (img - img.min()) / (img.max() - img.min()) diff --git a/Comp2Comp-main/comp2comp/utils/__init__.py b/Comp2Comp-main/comp2comp/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Comp2Comp-main/comp2comp/utils/colormap.py b/Comp2Comp-main/comp2comp/utils/colormap.py new file mode 100644 index 0000000000000000000000000000000000000000..05ef8684359af1e97e790d8c501375a512eec549 --- /dev/null +++ b/Comp2Comp-main/comp2comp/utils/colormap.py @@ -0,0 +1,156 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +""" +An awesome colormap for really neat visualizations. +Copied from Detectron, and removed gray colors. +""" + +import random + +import numpy as np + +__all__ = ["colormap", "random_color", "random_colors"] + +# fmt: off +# RGB: +_COLORS = np.array( + [ + 0.000, 0.447, 0.741, + 0.850, 0.325, 0.098, + 0.929, 0.694, 0.125, + 0.494, 0.184, 0.556, + 0.466, 0.674, 0.188, + 0.301, 0.745, 0.933, + 0.635, 0.078, 0.184, + 0.300, 0.300, 0.300, + 0.600, 0.600, 0.600, + 1.000, 0.000, 0.000, + 1.000, 0.500, 0.000, + 0.749, 0.749, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 1.000, + 0.667, 0.000, 1.000, + 0.333, 0.333, 0.000, + 0.333, 0.667, 0.000, + 0.333, 1.000, 0.000, + 0.667, 0.333, 0.000, + 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, + 1.000, 0.333, 0.000, + 1.000, 0.667, 0.000, + 1.000, 1.000, 0.000, + 0.000, 0.333, 0.500, + 0.000, 0.667, 0.500, + 0.000, 1.000, 0.500, + 0.333, 0.000, 0.500, + 0.333, 0.333, 0.500, + 0.333, 0.667, 0.500, + 0.333, 1.000, 0.500, + 0.667, 0.000, 0.500, + 0.667, 0.333, 0.500, + 0.667, 0.667, 0.500, + 0.667, 1.000, 0.500, + 1.000, 0.000, 0.500, + 1.000, 0.333, 0.500, + 1.000, 0.667, 0.500, + 1.000, 1.000, 0.500, + 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, + 0.000, 1.000, 1.000, + 0.333, 0.000, 1.000, + 0.333, 0.333, 1.000, + 0.333, 0.667, 1.000, + 0.333, 1.000, 1.000, + 0.667, 0.000, 1.000, + 0.667, 0.333, 1.000, + 0.667, 0.667, 1.000, + 0.667, 1.000, 1.000, + 1.000, 0.000, 1.000, + 1.000, 0.333, 1.000, + 1.000, 0.667, 1.000, + 0.333, 0.000, 0.000, + 0.500, 0.000, 0.000, + 0.667, 0.000, 0.000, + 0.833, 0.000, 0.000, + 1.000, 0.000, 0.000, + 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, + 0.000, 0.500, 0.000, + 0.000, 0.667, 0.000, + 0.000, 0.833, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 0.167, + 0.000, 0.000, 0.333, + 0.000, 0.000, 0.500, + 0.000, 0.000, 0.667, + 0.000, 0.000, 0.833, + 0.000, 0.000, 1.000, + 0.000, 0.000, 0.000, + 0.143, 0.143, 0.143, + 0.857, 0.857, 0.857, + 1.000, 1.000, 1.000 + ] +).astype(np.float32).reshape(-1, 3) +# fmt: on + + +def colormap(rgb=False, maximum=255): + """ + Args: + rgb (bool): whether to return RGB colors or BGR colors. + maximum (int): either 255 or 1 + Returns: + ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1] + """ + assert maximum in [255, 1], maximum + c = _COLORS * maximum + if not rgb: + c = c[:, ::-1] + return c + + +def random_color(rgb=False, maximum=255): + """ + Args: + rgb (bool): whether to return RGB colors or BGR colors. + maximum (int): either 255 or 1 + Returns: + ndarray: a vector of 3 numbers + """ + idx = np.random.randint(0, len(_COLORS)) + ret = _COLORS[idx] * maximum + if not rgb: + ret = ret[::-1] + return ret + + +def random_colors(N, rgb=False, maximum=255): + """ + Args: + N (int): number of unique colors needed + rgb (bool): whether to return RGB colors or BGR colors. 
+ maximum (int): either 255 or 1 + Returns: + ndarray: a list of random_color + """ + indices = random.sample(range(len(_COLORS)), N) + ret = [_COLORS[i] * maximum for i in indices] + if not rgb: + ret = [x[::-1] for x in ret] + return ret + + +if __name__ == "__main__": + import cv2 + + size = 100 + H, W = 10, 10 + canvas = np.random.rand(H * size, W * size, 3).astype("float32") + for h in range(H): + for w in range(W): + idx = h * W + w + if idx >= len(_COLORS): + break + canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx] + cv2.imshow("a", canvas) + cv2.waitKey(0) diff --git a/Comp2Comp-main/comp2comp/utils/dl_utils.py b/Comp2Comp-main/comp2comp/utils/dl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7a76d343f4845f0cf39adb9f18db42e20b9e33a4 --- /dev/null +++ b/Comp2Comp-main/comp2comp/utils/dl_utils.py @@ -0,0 +1,80 @@ +import subprocess + +from keras import Model + +# from keras.utils import multi_gpu_model +# from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model + + +def get_available_gpus(num_gpus: int = None): + """Get gpu ids for gpus that are >95% free. + + Tensorflow does not support checking free memory on gpus. + This is a crude method that relies on `nvidia-smi` to + determine which gpus are occupied and which are free. + + Args: + num_gpus: Number of requested gpus. If not specified, + ids of all available gpu(s) are returned. + + Returns: + List[int]: List of gpu ids that are free. Length + will equal `num_gpus`, if specified. + """ + # Built-in tensorflow gpu id. + assert isinstance(num_gpus, (type(None), int)) + if num_gpus == 0: + return [-1] + + num_requested_gpus = num_gpus + try: + num_gpus = ( + len( + subprocess.check_output("nvidia-smi --list-gpus", shell=True) + .decode() + .split("\n") + ) + - 1 + ) + + out_str = subprocess.check_output("nvidia-smi | grep MiB", shell=True).decode() + except subprocess.CalledProcessError: + return None + mem_str = [x for x in out_str.split() if "MiB" in x] + # First 2 * num_gpu elements correspond to memory for gpus + # Order: (occupied-0, total-0, occupied-1, total-1, ...) + mems = [float(x[:-3]) for x in mem_str] + gpu_percent_occupied_mem = [ + mems[2 * gpu_id] / mems[2 * gpu_id + 1] for gpu_id in range(num_gpus) + ] + + available_gpus = [ + gpu_id for gpu_id, mem in enumerate(gpu_percent_occupied_mem) if mem < 0.05 + ] + if num_requested_gpus and num_requested_gpus > len(available_gpus): + raise ValueError( + "Requested {} gpus, only {} are free".format( + num_requested_gpus, len(available_gpus) + ) + ) + + return available_gpus[:num_requested_gpus] if num_requested_gpus else available_gpus + + +class ModelMGPU(Model): + """Wrapper for distributing model across multiple gpus""" + + def __init__(self, ser_model, gpus): + pmodel = multi_gpu_model(ser_model, gpus) # noqa: F821 + self.__dict__.update(pmodel.__dict__) + self._smodel = ser_model + + def __getattribute__(self, attrname): + """Override load and save methods to be used from the serial-model. The + serial-model holds references to the weights in the multi-gpu model. 
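+        Saving through the serial model keeps checkpoints loadable on a single
+        device (a note on the standard Keras multi-GPU wrapper pattern).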
+ """ + # return Model.__getattribute__(self, attrname) + if "load" in attrname or "save" in attrname: + return getattr(self._smodel, attrname) + + return super(ModelMGPU, self).__getattribute__(attrname) diff --git a/Comp2Comp-main/comp2comp/utils/env.py b/Comp2Comp-main/comp2comp/utils/env.py new file mode 100644 index 0000000000000000000000000000000000000000..ebf48f46c4108d1bb1f194bc9b0c342abb420c03 --- /dev/null +++ b/Comp2Comp-main/comp2comp/utils/env.py @@ -0,0 +1,84 @@ +import importlib +import importlib.util +import os +import sys + +__all__ = [] + + +# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path # noqa +def _import_file(module_name, file_path, make_importable=False): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if make_importable: + sys.modules[module_name] = module + return module + + +def _configure_libraries(): + """ + Configurations for some libraries. + """ + # An environment option to disable `import cv2` globally, + # in case it leads to negative performance impact + disable_cv2 = int(os.environ.get("MEDSEGPY_DISABLE_CV2", False)) + if disable_cv2: + sys.modules["cv2"] = None + else: + # Disable opencl in opencv since its interaction with cuda often + # has negative effects + # This envvar is supported after OpenCV 3.4.0 + os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled" + try: + import cv2 + + if int(cv2.__version__.split(".")[0]) >= 3: + cv2.ocl.setUseOpenCL(False) + except ImportError: + pass + + +_ENV_SETUP_DONE = False + + +def setup_environment(): + """Perform environment setup work. The default setup is a no-op, but this + function allows the user to specify a Python source file or a module in + the $MEDSEGPY_ENV_MODULE environment variable, that performs + custom setup work that may be necessary to their computing environment. + """ + global _ENV_SETUP_DONE + if _ENV_SETUP_DONE: + return + _ENV_SETUP_DONE = True + + _configure_libraries() + + custom_module_path = os.environ.get("MEDSEGPY_ENV_MODULE") + + if custom_module_path: + setup_custom_environment(custom_module_path) + else: + # The default setup is a no-op + pass + + +def setup_custom_environment(custom_module): + """ + Load custom environment setup by importing a Python source file or a + module, and run the setup function. + """ + if custom_module.endswith(".py"): + module = _import_file("medsegpy.utils.env.custom_module", custom_module) + else: + module = importlib.import_module(custom_module) + assert hasattr(module, "setup_environment") and callable( + module.setup_environment + ), ( + "Custom environment module defined in {} does not have the " + "required callable attribute 'setup_environment'." + ).format( + custom_module + ) + module.setup_environment() diff --git a/Comp2Comp-main/comp2comp/utils/logger.py b/Comp2Comp-main/comp2comp/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..f7bb634d5a332eea3a1978fb110069834f2f6313 --- /dev/null +++ b/Comp2Comp-main/comp2comp/utils/logger.py @@ -0,0 +1,209 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import functools +import logging +import os +import sys +import time +from collections import Counter + +from termcolor import colored + +logging.captureWarnings(True) + + +class _ColorfulFormatter(logging.Formatter): + def __init__(self, *args, **kwargs): + self._root_name = kwargs.pop("root_name") + "." 
+ self._abbrev_name = kwargs.pop("abbrev_name", "") + if len(self._abbrev_name): + self._abbrev_name = self._abbrev_name + "." + super(_ColorfulFormatter, self).__init__(*args, **kwargs) + + def formatMessage(self, record): + record.name = record.name.replace(self._root_name, self._abbrev_name) + log = super(_ColorfulFormatter, self).formatMessage(record) + if record.levelno == logging.WARNING: + prefix = colored("WARNING", "red", attrs=["blink"]) + elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: + prefix = colored("ERROR", "red", attrs=["blink", "underline"]) + else: + return log + return prefix + " " + log + + +@functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers # noqa +def setup_logger( + output=None, + distributed_rank=0, + *, + color=True, + name="Comp2Comp", + abbrev_name=None, +): + """ + Initialize the detectron2 logger and set its verbosity level to "INFO". + + Args: + output (str): a file name or a directory to save log. If None, will not + save log file. If ends with ".txt" or ".log", assumed to be a file + name. Otherwise, logs will be saved to `output/log.txt`. + name (str): the root module name of this logger + abbrev_name (str): an abbreviation of the module, to avoid long names in + logs. Set to "" to not log the root module in logs. + By default, will abbreviate "detectron2" to "d2" and leave other + modules unchanged. + + Returns: + logging.Logger: a logger + """ + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + logger.propagate = False + if abbrev_name is None: + abbrev_name = name + + plain_formatter = logging.Formatter( + "[%(asctime)s] %(name)s %(levelname)s: %(message)s", + datefmt="%m/%d %H:%M:%S", + ) + # stdout logging: master only + if distributed_rank == 0: + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + if color: + formatter = _ColorfulFormatter( + colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", + datefmt="%m/%d %H:%M:%S", + root_name=name, + abbrev_name=str(abbrev_name), + ) + else: + formatter = plain_formatter + ch.setFormatter(formatter) + logger.addHandler(ch) + + # file logging: all workers + if output is not None: + if output.endswith(".txt") or output.endswith(".log"): + filename = output + else: + filename = os.path.join(output, "log.txt") + if distributed_rank > 0: + filename = filename + ".rank{}".format(distributed_rank) + os.makedirs(os.path.dirname(filename), exist_ok=True) + + fh = logging.StreamHandler(_cached_log_stream(filename)) + fh.setLevel(logging.DEBUG) + fh.setFormatter(plain_formatter) + logger.addHandler(fh) + + return logger + + +# cache the opened file object, so that different calls to `setup_logger` +# with the same file name can safely write to the same file. +@functools.lru_cache(maxsize=None) +def _cached_log_stream(filename): + return open(filename, "a") + + +""" +Below are some other convenient logging methods. 
+They are mainly adopted from +https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py +""" + + +def _find_caller(): + """ + Returns: + str: module name of the caller + tuple: a hashable key to be used to identify different callers + """ + frame = sys._getframe(2) + while frame: + code = frame.f_code + if os.path.join("utils", "logger.") not in code.co_filename: + mod_name = frame.f_globals["__name__"] + if mod_name == "__main__": + mod_name = "detectron2" + return mod_name, (code.co_filename, frame.f_lineno, code.co_name) + frame = frame.f_back + + +_LOG_COUNTER = Counter() +_LOG_TIMER = {} + + +def log_first_n(lvl, msg, n=1, *, name=None, key="caller"): + """ + Log only for the first n times. + + Args: + lvl (int): the logging level + msg (str): + n (int): + name (str): name of the logger to use. Will use the caller's module by + default. + key (str or tuple[str]): the string(s) can be one of "caller" or + "message", which defines how to identify duplicated logs. + For example, if called with `n=1, key="caller"`, this function + will only log the first call from the same caller, regardless of + the message content. + If called with `n=1, key="message"`, this function will log the + same content only once, even if they are called from different + places. + If called with `n=1, key=("caller", "message")`, this function + will not log only if the same caller has logged the same message + before. + """ + if isinstance(key, str): + key = (key,) + assert len(key) > 0 + + caller_module, caller_key = _find_caller() + hash_key = () + if "caller" in key: + hash_key = hash_key + caller_key + if "message" in key: + hash_key = hash_key + (msg,) + + _LOG_COUNTER[hash_key] += 1 + if _LOG_COUNTER[hash_key] <= n: + logging.getLogger(name or caller_module).log(lvl, msg) + + +def log_every_n(lvl, msg, n=1, *, name=None): + """ + Log once per n times. + + Args: + lvl (int): the logging level + msg (str): + n (int): + name (str): name of the logger to use. Will use the caller's module by + default. + """ + caller_module, key = _find_caller() + _LOG_COUNTER[key] += 1 + if n == 1 or _LOG_COUNTER[key] % n == 1: + logging.getLogger(name or caller_module).log(lvl, msg) + + +def log_every_n_seconds(lvl, msg, n=1, *, name=None): + """ + Log no more than once per n seconds. + + Args: + lvl (int): the logging level + msg (str): + n (int): + name (str): name of the logger to use. Will use the caller's module by + default. + """ + caller_module, key = _find_caller() + last_logged = _LOG_TIMER.get(key, None) + current_time = time.time() + if last_logged is None or current_time - last_logged >= n: + logging.getLogger(name or caller_module).log(lvl, msg) + _LOG_TIMER[key] = current_time diff --git a/Comp2Comp-main/comp2comp/utils/orientation.py b/Comp2Comp-main/comp2comp/utils/orientation.py new file mode 100644 index 0000000000000000000000000000000000000000..4481eb85ad05b8bb8f06dfe14f1832fd5a454d54 --- /dev/null +++ b/Comp2Comp-main/comp2comp/utils/orientation.py @@ -0,0 +1,30 @@ +import nibabel as nib + +from comp2comp.inference_class_base import InferenceClass + + +class ToCanonical(InferenceClass): + """Convert spine segmentation to canonical orientation.""" + + def __init__(self): + super().__init__() + + def __call__(self, inference_pipeline): + """ + First dim goes from L to R. + Second dim goes from P to A. + Third dim goes from I to S. 
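+        (This is RAS+ orientation, the convention nib.as_closest_canonical targets.)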
+ """ + canonical_segmentation = nib.as_closest_canonical( + inference_pipeline.segmentation + ) + canonical_medical_volume = nib.as_closest_canonical( + inference_pipeline.medical_volume + ) + + inference_pipeline.segmentation = canonical_segmentation + inference_pipeline.medical_volume = canonical_medical_volume + inference_pipeline.pixel_spacing_list = ( + canonical_medical_volume.header.get_zooms() + ) + return {} diff --git a/Comp2Comp-main/comp2comp/utils/process.py b/Comp2Comp-main/comp2comp/utils/process.py new file mode 100644 index 0000000000000000000000000000000000000000..491f9576580d46b3845cd68f80880fcaa0f7f78c --- /dev/null +++ b/Comp2Comp-main/comp2comp/utils/process.py @@ -0,0 +1,120 @@ +""" +@author: louisblankemeier +""" + +import os +import shutil +import sys +import time +import traceback +from datetime import datetime +from pathlib import Path + +from comp2comp.io import io_utils + + +def process_2d(args, pipeline_builder): + output_dir = Path( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "../../outputs", + datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), + ) + ) + if not os.path.exists(output_dir): + output_dir.mkdir(parents=True) + + model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../models") + if not os.path.exists(model_dir): + os.mkdir(model_dir) + + pipeline = pipeline_builder(args) + + pipeline(output_dir=output_dir, model_dir=model_dir) + + +def process_3d(args, pipeline_builder): + model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../models") + if not os.path.exists(model_dir): + os.mkdir(model_dir) + + if args.output_path is not None: + output_path = Path(args.output_path) + else: + output_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../../outputs" + ) + + if not args.overwrite_outputs: + date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + output_path = os.path.join(output_path, date_time) + + for path, num in io_utils.get_dicom_or_nifti_paths_and_num(args.input_path): + try: + st = time.time() + + if path.endswith(".nii") or path.endswith(".nii.gz"): + print("Processing: ", path) + + else: + print("Processing: ", path, " with ", num, " slices") + min_slices = 30 + if num < min_slices: + print(f"Number of slices is less than {min_slices}, skipping\n") + continue + + print("") + + try: + sys.stdout.flush() + except Exception: + pass + + if path.endswith(".nii") or path.endswith(".nii.gz"): + folder_name = Path(os.path.basename(os.path.normpath(path))) + # remove .nii or .nii.gz + folder_name = os.path.normpath( + Path(str(folder_name).replace(".gz", "").replace(".nii", "")) + ) + output_dir = Path( + os.path.join( + output_path, + folder_name, + ) + ) + + else: + output_dir = Path( + os.path.join( + output_path, + Path(os.path.basename(os.path.normpath(args.input_path))), + os.path.relpath( + os.path.normpath(path), os.path.normpath(args.input_path) + ), + ) + ) + + if not os.path.exists(output_dir): + output_dir.mkdir(parents=True) + + pipeline = pipeline_builder(path, args) + + pipeline(output_dir=output_dir, model_dir=model_dir) + + if not args.save_segmentations: + # remove the segmentations folder + segmentations_dir = os.path.join(output_dir, "segmentations") + if os.path.exists(segmentations_dir): + shutil.rmtree(segmentations_dir) + + print(f"Finished processing {path} in {time.time() - st:.1f} seconds\n") + + except Exception: + print(f"ERROR PROCESSING {path}\n") + traceback.print_exc() + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + # remove 
parent folder if empty
+            if len(os.listdir(os.path.dirname(output_dir))) == 0:
+                shutil.rmtree(os.path.dirname(output_dir))
+            continue
diff --git a/Comp2Comp-main/comp2comp/utils/run.py b/Comp2Comp-main/comp2comp/utils/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2f6bb84e5984a92a1c46bef03357766a38c7efc
--- /dev/null
+++ b/Comp2Comp-main/comp2comp/utils/run.py
@@ -0,0 +1,126 @@
+import logging
+import os
+import re
+from typing import Sequence, Union
+
+logger = logging.getLogger(__name__)
+
+
+def format_output_path(
+    file_path,
+    save_dir: str = None,
+    base_dirs: Sequence[str] = None,
+    file_name: str = None,
+):
+    """Format output path for a given file.
+
+    Args:
+        file_path (str): File path.
+        save_dir (str, optional): Save directory. Defaults to None.
+        base_dirs (Sequence[str], optional): Base directories. Defaults to None.
+        file_name (str, optional): File name (without extension). Defaults to None.
+
+    Returns:
+        str: Output path.
+    """
+
+    dirname = os.path.dirname(file_path) if not save_dir else save_dir
+
+    if save_dir and base_dirs:
+        dirname: str = os.path.dirname(file_path)
+        relative_dir = [
+            dirname.split(bdir, 1)[1] for bdir in base_dirs if dirname.startswith(bdir)
+        ][0]
+        # Trim path separator from the path
+        relative_dir = relative_dir.lstrip(os.path.sep)
+        dirname = os.path.join(save_dir, relative_dir)
+
+    if file_name is not None:
+        return os.path.join(
+            dirname,
+            "{}.h5".format(file_name),
+        )
+
+    return os.path.join(
+        dirname,
+        "{}.h5".format(os.path.splitext(os.path.basename(file_path))[0]),
+    )
+
+
+# Function that returns a list of file names, excluding
+# the extension, from the list of file paths
+def get_file_names(files):
+    """Get file names from a list of file paths.
+
+    Args:
+        files (list): List of file paths.
+
+    Returns:
+        list: List of file names.
+    """
+    file_names = []
+    for file in files:
+        file_name = os.path.splitext(os.path.basename(file))[0]
+        file_names.append(file_name)
+    return file_names
+
+
+def find_files(
+    root_dirs: Union[str, Sequence[str]],
+    max_depth: int = None,
+    exist_ok: bool = False,
+    pattern: str = None,
+):
+    """Recursively search for files.
+
+    To avoid recomputing experiments with results, set `exist_ok=False`.
+    Results will be searched for in `PREFERENCES.OUTPUT_DIR` (if non-empty).
+
+    Args:
+        root_dirs (`str(s)`): Root folder(s) to search.
+        max_depth (int, optional): Maximum depth to search.
+        exist_ok (bool, optional): If `True`, recompute results for scans
+            whose outputs already exist.
+        pattern (str, optional): If specified, looks for files with names
+            matching the pattern.
+
+    Return:
+        List[str]: Paths of matching files to process.
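+
+    Example (hypothetical directory layout):
+        >>> find_files("/data/scans", max_depth=2, pattern=r".*\.dcm$")  # doctest: +SKIP
+        ['/data/scans/case1/0001.dcm', '/data/scans/case2/0001.dcm']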
+ """ + + def _get_files(depth: int, dir_name: str): + if dir_name is None or not os.path.isdir(dir_name): + return [] + + if max_depth is not None and depth > max_depth: + return [] + + files = os.listdir(dir_name) + ret_files = [] + for file in files: + possible_dir = os.path.join(dir_name, file) + if os.path.isdir(possible_dir): + subfiles = _get_files(depth + 1, possible_dir) + ret_files.extend(subfiles) + elif os.path.isfile(possible_dir): + if pattern and not re.match(pattern, possible_dir): + continue + output_path = format_output_path(possible_dir) + if not exist_ok and os.path.isfile(output_path): + logger.info( + "Skipping {} - results exist at {}".format( + possible_dir, output_path + ) + ) + continue + ret_files.append(possible_dir) + + return ret_files + + out_files = [] + if isinstance(root_dirs, str): + root_dirs = [root_dirs] + for d in root_dirs: + out_files.extend(_get_files(0, d)) + + return sorted(set(out_files)) diff --git a/Comp2Comp-main/comp2comp/visualization/detectron_visualizer.py b/Comp2Comp-main/comp2comp/visualization/detectron_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d2d3746d5d3be4781cb6b9372fc8c96f2180a430 --- /dev/null +++ b/Comp2Comp-main/comp2comp/visualization/detectron_visualizer.py @@ -0,0 +1,1288 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import colorsys +import logging +import math +from enum import Enum, unique +from pathlib import Path + +import cv2 +import matplotlib as mpl +import matplotlib.colors as mplc +import matplotlib.figure as mplfigure +import numpy as np +import pycocotools.mask as mask_util +import torch +from matplotlib.backends.backend_agg import FigureCanvasAgg + +from comp2comp.utils.colormap import random_color +from comp2comp.visualization.dicom import to_dicom + +logger = logging.getLogger(__name__) + +__all__ = ["ColorMode", "VisImage", "Visualizer"] + + +_SMALL_OBJECT_AREA_THRESH = 1000 +_LARGE_MASK_AREA_THRESH = 120000 +_OFF_WHITE = (1.0, 1.0, 240.0 / 255) +_BLACK = (0, 0, 0) +_RED = (1.0, 0, 0) + +_KEYPOINT_THRESHOLD = 0.05 + + +@unique +class ColorMode(Enum): + """ + Enum of different color modes to use for instance visualizations. + """ + + IMAGE = 0 + """ + Picks a random color for every instance and overlay segmentations with low opacity. + """ + SEGMENTATION = 1 + """ + Let instances of the same category have similar colors + (from metadata.thing_colors), and overlay them with + high opacity. This provides more attention on the quality of segmentation. + """ + IMAGE_BW = 2 + """ + Same as IMAGE, but convert all areas without masks to gray-scale. + Only available for drawing per-instance mask predictions. + """ + + +class GenericMask: + """ + Attribute: + polygons (list[ndarray]): list[ndarray]: polygons for this mask. + Each ndarray has format [x, y, x, y, ...] 
+ mask (ndarray): a binary mask + """ + + def __init__(self, mask_or_polygons, height, width): + self._mask = self._polygons = self._has_holes = None + self.height = height + self.width = width + + m = mask_or_polygons + if isinstance(m, dict): + # RLEs + assert "counts" in m and "size" in m + if isinstance(m["counts"], list): # uncompressed RLEs + h, w = m["size"] + assert h == height and w == width + m = mask_util.frPyObjects(m, h, w) + self._mask = mask_util.decode(m)[:, :] + return + + if isinstance(m, list): # list[ndarray] + self._polygons = [np.asarray(x).reshape(-1) for x in m] + return + + if isinstance(m, np.ndarray): # assumed to be a binary mask + assert m.shape[1] != 2, m.shape + assert m.shape == ( + height, + width, + ), f"mask shape: {m.shape}, target dims: {height}, {width}" + self._mask = m.astype("uint8") + return + + raise ValueError( + "GenericMask cannot handle object {} of type '{}'".format(m, type(m)) + ) + + @property + def mask(self): + if self._mask is None: + self._mask = self.polygons_to_mask(self._polygons) + return self._mask + + @property + def polygons(self): + if self._polygons is None: + self._polygons, self._has_holes = self.mask_to_polygons(self._mask) + return self._polygons + + @property + def has_holes(self): + if self._has_holes is None: + if self._mask is not None: + self._polygons, self._has_holes = self.mask_to_polygons(self._mask) + else: + self._has_holes = ( + False # if original format is polygon, does not have holes + ) + return self._has_holes + + def mask_to_polygons(self, mask): + # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level + # hierarchy. External contours (boundary) of the object are placed in hierarchy-1. + # Internal contours (holes) are placed in hierarchy-2. + # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours. + mask = np.ascontiguousarray( + mask + ) # some versions of cv2 does not support incontiguous arr + res = cv2.findContours( + mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE + ) + hierarchy = res[-1] + if hierarchy is None: # empty mask + return [], False + has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0 + res = res[-2] + res = [x.flatten() for x in res] + # These coordinates from OpenCV are integers in range [0, W-1 or H-1]. + # We add 0.5 to turn them into real-value coordinate space. A better solution + # would be to first +0.5 and then dilate the returned polygon by 0.5. + res = [x + 0.5 for x in res if len(x) >= 6] + return res, has_holes + + def polygons_to_mask(self, polygons): + rle = mask_util.frPyObjects(polygons, self.height, self.width) + rle = mask_util.merge(rle) + return mask_util.decode(rle)[:, :] + + def area(self): + return self.mask.sum() + + def bbox(self): + p = mask_util.frPyObjects(self.polygons, self.height, self.width) + p = mask_util.merge(p) + bbox = mask_util.toBbox(p) + bbox[2] += bbox[0] + bbox[3] += bbox[1] + return bbox + + +class _PanopticPrediction: + """ + Unify different panoptic annotation/prediction formats + """ + + def __init__(self, panoptic_seg, segments_info, metadata=None): + if segments_info is None: + assert metadata is not None + # If "segments_info" is None, we assume "panoptic_img" is a + # H*W int32 image storing the panoptic_id in the format of + # category_id * label_divisor + instance_id. We reserve -1 for + # VOID label. + label_divisor = metadata.label_divisor + segments_info = [] + for panoptic_label in np.unique(panoptic_seg.numpy()): + if panoptic_label == -1: + # VOID region. 
+                    continue
+                pred_class = panoptic_label // label_divisor
+                isthing = (
+                    pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
+                )
+                segments_info.append(
+                    {
+                        "id": int(panoptic_label),
+                        "category_id": int(pred_class),
+                        "isthing": bool(isthing),
+                    }
+                )
+            del metadata
+
+        self._seg = panoptic_seg
+
+        self._sinfo = {s["id"]: s for s in segments_info}  # seg id -> seg info
+        segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
+        areas = areas.numpy()
+        sorted_idxs = np.argsort(-areas)
+        self._seg_ids, self._seg_areas = (
+            segment_ids[sorted_idxs],
+            areas[sorted_idxs],
+        )
+        self._seg_ids = self._seg_ids.tolist()
+        for sid, area in zip(self._seg_ids, self._seg_areas):
+            if sid in self._sinfo:
+                self._sinfo[sid]["area"] = float(area)
+
+    def non_empty_mask(self):
+        """
+        Returns:
+            (H, W) array, a mask for all pixels that have a prediction
+        """
+        empty_ids = []
+        for id in self._seg_ids:
+            if id not in self._sinfo:
+                empty_ids.append(id)
+        if len(empty_ids) == 0:
+            return np.zeros(self._seg.shape, dtype=np.uint8)
+        assert (
+            len(empty_ids) == 1
+        ), ">1 ids correspond to no labels. This is currently not supported"
+        return (self._seg != empty_ids[0]).numpy().astype(bool)
+
+    def semantic_masks(self):
+        for sid in self._seg_ids:
+            sinfo = self._sinfo.get(sid)
+            if sinfo is None or sinfo["isthing"]:
+                # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
+                continue
+            yield (self._seg == sid).numpy().astype(bool), sinfo
+
+    def instance_masks(self):
+        for sid in self._seg_ids:
+            sinfo = self._sinfo.get(sid)
+            if sinfo is None or not sinfo["isthing"]:
+                continue
+            mask = (self._seg == sid).numpy().astype(bool)
+            if mask.sum() > 0:
+                yield mask, sinfo
+
+
+def _create_text_labels(classes, scores, class_names, is_crowd=None):
+    """
+    Args:
+        classes (list[int] or None):
+        scores (list[float] or None):
+        class_names (list[str] or None):
+        is_crowd (list[bool] or None):
+
+    Returns:
+        list[str] or None
+    """
+    labels = None
+    if classes is not None:
+        if class_names is not None and len(class_names) > 0:
+            labels = [class_names[i] for i in classes]
+        else:
+            labels = [str(i) for i in classes]
+    if scores is not None:
+        if labels is None:
+            labels = ["{:.0f}%".format(s * 100) for s in scores]
+        else:
+            labels = [
+                "{} {:.0f}%".format(lbl, s * 100) for lbl, s in zip(labels, scores)
+            ]
+    if labels is not None and is_crowd is not None:
+        labels = [
+            lbl + ("|crowd" if crowd else "") for lbl, crowd in zip(labels, is_crowd)
+        ]
+    return labels
+
+
+class VisImage:
+    def __init__(self, img, scale=1.0):
+        """
+        Args:
+            img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
+            scale (float): scale the input image
+        """
+        self.img = img
+        self.scale = scale
+        self.width, self.height = img.shape[1], img.shape[0]
+        self._setup_figure(img)
+
+    def _setup_figure(self, img):
+        """
+        Args:
+            Same as in :meth:`__init__()`.
+
+        Returns:
+            fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
+            ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
+ """ + fig = mplfigure.Figure(frameon=False) + self.dpi = fig.get_dpi() + # add a small 1e-2 to avoid precision lost due to matplotlib's truncation + # (https://github.com/matplotlib/matplotlib/issues/15363) + fig.set_size_inches( + (self.width * self.scale + 1e-2) / self.dpi, + (self.height * self.scale + 1e-2) / self.dpi, + ) + self.canvas = FigureCanvasAgg(fig) + # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) + ax = fig.add_axes([0.0, 0.0, 1.0, 1.0]) + ax.axis("off") + self.fig = fig + self.ax = ax + self.reset_image(img) + + def reset_image(self, img): + """ + Args: + img: same as in __init__ + """ + img = img.astype("uint8") + self.ax.imshow( + img, extent=(0, self.width, self.height, 0), interpolation="nearest" + ) + + def save(self, filepath): + """ + Args: + filepath (str): a string that contains the absolute path, including the file name, where + the visualized image will be saved. + """ + # if filepath is a png or jpg + img = self.get_image() + if filepath.endswith(".png") or filepath.endswith(".jpg"): + self.fig.savefig(filepath) + if filepath.endswith(".dcm"): + to_dicom(img, Path(filepath)) + return img + + def get_image(self): + """ + Returns: + ndarray: + the visualized image of shape (H, W, 3) (RGB) in uint8 type. + The shape is scaled w.r.t the input image using the given `scale` argument. + """ + canvas = self.canvas + s, (width, height) = canvas.print_to_buffer() + # buf = io.BytesIO() # works for cairo backend + # canvas.print_rgba(buf) + # width, height = self.width, self.height + # s = buf.getvalue() + + buffer = np.frombuffer(s, dtype="uint8") + + img_rgba = buffer.reshape(height, width, 4) + rgb, alpha = np.split(img_rgba, [3], axis=2) + return rgb.astype("uint8") + + +class Visualizer: + """ + Visualizer that draws data about detection/segmentation on images. + + It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}` + that draw primitive objects to images, as well as high-level wrappers like + `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}` + that draw composite data in some pre-defined style. + + Note that the exact visualization style for the high-level wrappers are subject to change. + Style such as color, opacity, label contents, visibility of labels, or even the visibility + of objects themselves (e.g. when the object is too small) may change according + to different heuristics, as long as the results still look visually reasonable. + + To obtain a consistent style, you can implement custom drawing functions with the + abovementioned primitive methods instead. If you need more customized visualization + styles, you can process the data yourself following their format documented in + tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not + intend to satisfy everyone's preference on drawing styles. + + This visualizer focuses on high rendering quality rather than performance. It is not + designed to be used for real-time applications. + """ + + # TODO implement a fast, rasterized version using OpenCV + + def __init__( + self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE + ): + """ + Args: + img_rgb: a numpy array of shape (H, W, C), where H and W correspond to + the height and width of the image respectively. C is the number of + color channels. The image is required to be in RGB format since that + is a requirement of the Matplotlib library. The image is also expected + to be in the range [0, 255]. + metadata (Metadata): dataset metadata (e.g. 
class names and colors) + instance_mode (ColorMode): defines one of the pre-defined style for drawing + instances on an image. + """ + self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8) + # if metadata is None: + # metadata = MetadataCatalog.get("__nonexist__") + self.metadata = metadata + self.output = VisImage(self.img, scale=scale) + self.cpu_device = torch.device("cpu") + + # too small texts are useless, therefore clamp to 9 + self._default_font_size = max( + np.sqrt(self.output.height * self.output.width) // 90, 10 // scale + ) + self._instance_mode = instance_mode + self.keypoint_threshold = _KEYPOINT_THRESHOLD + + def draw_instance_predictions(self, predictions): + """ + Draw instance-level prediction results on an image. + + Args: + predictions (Instances): the output of an instance detection/segmentation + model. Following fields will be used to draw: + "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle"). + + Returns: + output (VisImage): image object with visualizations. + """ + boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None + scores = predictions.scores if predictions.has("scores") else None + classes = ( + predictions.pred_classes.tolist() + if predictions.has("pred_classes") + else None + ) + labels = _create_text_labels( + classes, scores, self.metadata.get("thing_classes", None) + ) + keypoints = ( + predictions.pred_keypoints if predictions.has("pred_keypoints") else None + ) + + if predictions.has("pred_masks"): + masks = np.asarray(predictions.pred_masks) + masks = [ + GenericMask(x, self.output.height, self.output.width) for x in masks + ] + else: + masks = None + + if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get( + "thing_colors" + ): + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) + for c in classes + ] + alpha = 0.8 + else: + colors = None + alpha = 0.5 + + if self._instance_mode == ColorMode.IMAGE_BW: + self.output.reset_image( + self._create_grayscale_image( + (predictions.pred_masks.any(dim=0) > 0).numpy() + if predictions.has("pred_masks") + else None + ) + ) + alpha = 0.3 + + self.overlay_instances( + masks=masks, + boxes=boxes, + labels=labels, + keypoints=keypoints, + assigned_colors=colors, + alpha=alpha, + ) + return self.output + + def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8): + """ + Draw semantic segmentation predictions/labels. + + Args: + sem_seg (Tensor or ndarray): the segmentation of shape (H, W). + Each value is the integer label of the pixel. + area_threshold (int): segments with less than `area_threshold` are not drawn. + alpha (float): the larger it is, the more opaque the segmentations are. + + Returns: + output (VisImage): image object with visualizations. 
+ """ + if isinstance(sem_seg, torch.Tensor): + sem_seg = sem_seg.numpy() + labels, areas = np.unique(sem_seg, return_counts=True) + sorted_idxs = np.argsort(-areas).tolist() + labels = labels[sorted_idxs] + for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels): + try: + mask_color = [x / 255 for x in self.metadata.stuff_colors[label]] + except (AttributeError, IndexError): + mask_color = None + + binary_mask = (sem_seg == label).astype(np.uint8) + text = self.metadata.stuff_classes[label] + self.draw_binary_mask( + binary_mask, + color=mask_color, + edge_color=_OFF_WHITE, + text=text, + alpha=alpha, + area_threshold=area_threshold, + ) + return self.output + + def draw_panoptic_seg( + self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7 + ): + """ + Draw panoptic prediction annotations or results. + + Args: + panoptic_seg (Tensor): of shape (height, width) where the values are ids for each + segment. + segments_info (list[dict] or None): Describe each segment in `panoptic_seg`. + If it is a ``list[dict]``, each dict contains keys "id", "category_id". + If None, category id of each pixel is computed by + ``pixel // metadata.label_divisor``. + area_threshold (int): stuff segments with less than `area_threshold` are not drawn. + + Returns: + output (VisImage): image object with visualizations. + """ + pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata) + + if self._instance_mode == ColorMode.IMAGE_BW: + self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask())) + + # draw mask for all semantic segments first i.e. "stuff" + for mask, sinfo in pred.semantic_masks(): + category_idx = sinfo["category_id"] + try: + mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]] + except AttributeError: + mask_color = None + + text = self.metadata.stuff_classes[category_idx] + self.draw_binary_mask( + mask, + color=mask_color, + edge_color=_OFF_WHITE, + text=text, + alpha=alpha, + area_threshold=area_threshold, + ) + + # draw mask for all instances second + all_instances = list(pred.instance_masks()) + if len(all_instances) == 0: + return self.output + masks, sinfo = list(zip(*all_instances)) + category_ids = [x["category_id"] for x in sinfo] + + try: + scores = [x["score"] for x in sinfo] + except KeyError: + scores = None + labels = _create_text_labels( + category_ids, + scores, + self.metadata.thing_classes, + [x.get("iscrowd", 0) for x in sinfo], + ) + + try: + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) + for c in category_ids + ] + except AttributeError: + colors = None + self.overlay_instances( + masks=masks, labels=labels, assigned_colors=colors, alpha=alpha + ) + + return self.output + + draw_panoptic_seg_predictions = draw_panoptic_seg # backward compatibility + + def overlay_instances( + self, + *, + boxes=None, + labels=None, + masks=None, + keypoints=None, + assigned_colors=None, + alpha=0.5, + ): + """ + Args: + boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`, + or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image, + or a :class:`RotatedBoxes`, + or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format + for the N objects in a single image, + labels (list[str]): the text to be displayed for each instance. + masks (masks-like object): Supported types are: + + * :class:`detectron2.structures.PolygonMasks`, + :class:`detectron2.structures.BitMasks`. 
+ * list[list[ndarray]]: contains the segmentation masks for all objects in one image. + The first level of the list corresponds to individual instances. The second + level to all the polygon that compose the instance, and the third level + to the polygon coordinates. The third level should have the format of + [x0, y0, x1, y1, ..., xn, yn] (n >= 3). + * list[ndarray]: each ndarray is a binary mask of shape (H, W). + * list[dict]: each dict is a COCO-style RLE. + keypoints (Keypoint or array like): an array-like object of shape (N, K, 3), + where the N is the number of instances and K is the number of keypoints. + The last dimension corresponds to (x, y, visibility or score). + assigned_colors (list[matplotlib.colors]): a list of colors, where each color + corresponds to each mask or box in the image. Refer to 'matplotlib.colors' + for full list of formats that the colors are accepted in. + Returns: + output (VisImage): image object with visualizations. + """ + num_instances = 0 + if boxes is not None: + boxes = self._convert_boxes(boxes) + num_instances = len(boxes) + if masks is not None: + masks = self._convert_masks(masks) + if num_instances: + assert len(masks) == num_instances + else: + num_instances = len(masks) + if keypoints is not None: + if num_instances: + assert len(keypoints) == num_instances + else: + num_instances = len(keypoints) + keypoints = self._convert_keypoints(keypoints) + if labels is not None: + assert len(labels) == num_instances + if assigned_colors is None: + assigned_colors = [ + random_color(rgb=True, maximum=1) for _ in range(num_instances) + ] + if num_instances == 0: + return self.output + if boxes is not None and boxes.shape[1] == 5: + return self.overlay_rotated_instances( + boxes=boxes, labels=labels, assigned_colors=assigned_colors + ) + + # Display in largest to smallest order to reduce occlusion. + areas = None + if boxes is not None: + areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1) + elif masks is not None: + areas = np.asarray([x.area() for x in masks]) + + if areas is not None: + sorted_idxs = np.argsort(-areas).tolist() + # Re-order overlapped instances in descending order. + boxes = boxes[sorted_idxs] if boxes is not None else None + labels = [labels[k] for k in sorted_idxs] if labels is not None else None + masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None + assigned_colors = [assigned_colors[idx] for idx in sorted_idxs] + keypoints = keypoints[sorted_idxs] if keypoints is not None else None + + for i in range(num_instances): + color = assigned_colors[i] + if boxes is not None: + self.draw_box(boxes[i], edge_color=color) + + if masks is not None: + for segment in masks[i].polygons: + self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha) + + if labels is not None: + # first get a box + if boxes is not None: + x0, y0, x1, y1 = boxes[i] + text_pos = ( + x0, + y0, + ) # if drawing boxes, put text on the box corner. + horiz_align = "left" + elif masks is not None: + # skip small mask without polygon + if len(masks[i].polygons) == 0: + continue + + x0, y0, x1, y1 = masks[i].bbox() + + # draw text in the center (defined by median) when box is not drawn + # median is less sensitive to outliers. + text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1] + horiz_align = "center" + else: + continue # drawing the box confidence for keypoints isn't very useful. 
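+                # At this point text_pos and horiz_align are set from either the
+                # box corner or the mask median; the heuristics below relocate
+                # the label and scale the font for very small instances.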
+ # for small objects, draw text at the side to avoid occlusion + instance_area = (y1 - y0) * (x1 - x0) + if ( + instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale + or y1 - y0 < 40 * self.output.scale + ): + if y1 >= self.output.height - 5: + text_pos = (x1, y0) + else: + text_pos = (x0, y1) + + height_ratio = (y1 - y0) / np.sqrt( + self.output.height * self.output.width + ) + lighter_color = self._change_color_brightness( + color, brightness_factor=0.7 + ) + font_size = ( + np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) + * 0.5 + * self._default_font_size + ) + self.draw_text( + labels[i], + text_pos, + color=lighter_color, + horizontal_alignment=horiz_align, + font_size=font_size, + ) + + # draw keypoints + if keypoints is not None: + for keypoints_per_instance in keypoints: + self.draw_and_connect_keypoints(keypoints_per_instance) + + return self.output + + def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None): + """ + Args: + boxes (ndarray): an Nx5 numpy array of + (x_center, y_center, width, height, angle_degrees) format + for the N objects in a single image. + labels (list[str]): the text to be displayed for each instance. + assigned_colors (list[matplotlib.colors]): a list of colors, where each color + corresponds to each mask or box in the image. Refer to 'matplotlib.colors' + for full list of formats that the colors are accepted in. + + Returns: + output (VisImage): image object with visualizations. + """ + num_instances = len(boxes) + + if assigned_colors is None: + assigned_colors = [ + random_color(rgb=True, maximum=1) for _ in range(num_instances) + ] + if num_instances == 0: + return self.output + + # Display in largest to smallest order to reduce occlusion. + if boxes is not None: + areas = boxes[:, 2] * boxes[:, 3] + + sorted_idxs = np.argsort(-areas).tolist() + # Re-order overlapped instances in descending order. + boxes = boxes[sorted_idxs] + labels = [labels[k] for k in sorted_idxs] if labels is not None else None + colors = [assigned_colors[idx] for idx in sorted_idxs] + + for i in range(num_instances): + self.draw_rotated_box_with_label( + boxes[i], + edge_color=colors[i], + label=labels[i] if labels is not None else None, + ) + + return self.output + + def draw_and_connect_keypoints(self, keypoints): + """ + Draws keypoints of an instance and follows the rules for keypoint connections + to draw lines between appropriate keypoints. This follows color heuristics for + line color. + + Args: + keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints + and the last dimension corresponds to (x, y, probability). + + Returns: + output (VisImage): image object with visualizations. + """ + visible = {} + keypoint_names = self.metadata.get("keypoint_names") + for idx, keypoint in enumerate(keypoints): + # draw keypoint + x, y, prob = keypoint + if prob > self.keypoint_threshold: + self.draw_circle((x, y), color=_RED) + if keypoint_names: + keypoint_name = keypoint_names[idx] + visible[keypoint_name] = (x, y) + + if self.metadata.get("keypoint_connection_rules"): + for kp0, kp1, color in self.metadata.keypoint_connection_rules: + if kp0 in visible and kp1 in visible: + x0, y0 = visible[kp0] + x1, y1 = visible[kp1] + color = tuple(x / 255.0 for x in color) + self.draw_line([x0, x1], [y0, y1], color=color) + + # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip + # Note that this strategy is specific to person keypoints. 
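+        # (Shoulder/hip midpoints below are plain coordinate averages, and a
+        # connecting line is drawn only when both of its endpoint keypoints
+        # scored above the keypoint threshold, i.e. appear in `visible`.)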
+ # For other keypoints, it should just do nothing + try: + ls_x, ls_y = visible["left_shoulder"] + rs_x, rs_y = visible["right_shoulder"] + mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2 + except KeyError: + pass + else: + # draw line from nose to mid-shoulder + nose_x, nose_y = visible.get("nose", (None, None)) + if nose_x is not None: + self.draw_line( + [nose_x, mid_shoulder_x], + [nose_y, mid_shoulder_y], + color=_RED, + ) + + try: + # draw line from mid-shoulder to mid-hip + lh_x, lh_y = visible["left_hip"] + rh_x, rh_y = visible["right_hip"] + except KeyError: + pass + else: + mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2 + self.draw_line( + [mid_hip_x, mid_shoulder_x], + [mid_hip_y, mid_shoulder_y], + color=_RED, + ) + return self.output + + """ + Primitive drawing functions: + """ + + def draw_text( + self, + text, + position, + *, + font_size=None, + color="g", + horizontal_alignment="center", + rotation=0, + ): + """ + Args: + text (str): class label + position (tuple): a tuple of the x and y coordinates to place text on image. + font_size (int, optional): font of the text. If not provided, a font size + proportional to the image width is calculated and used. + color: color of the text. Refer to `matplotlib.colors` for full list + of formats that are accepted. + horizontal_alignment (str): see `matplotlib.text.Text` + rotation: rotation angle in degrees CCW + + Returns: + output (VisImage): image object with text drawn. + """ + if not font_size: + font_size = self._default_font_size + + # since the text background is dark, we don't want the text to be dark + color = np.maximum(list(mplc.to_rgb(color)), 0.2) + color[np.argmax(color)] = max(0.8, np.max(color)) + + x, y = position + self.output.ax.text( + x, + y, + text, + size=font_size * self.output.scale, + family="sans-serif", + bbox={ + "facecolor": "black", + "alpha": 0.8, + "pad": 0.7, + "edgecolor": "none", + }, + verticalalignment="top", + horizontalalignment=horizontal_alignment, + color=color, + zorder=10, + rotation=rotation, + ) + return self.output + + def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"): + """ + Args: + box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0 + are the coordinates of the image's top left corner. x1 and y1 are the + coordinates of the image's bottom right corner. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + edge_color: color of the outline of the box. Refer to `matplotlib.colors` + for full list of formats that are accepted. + line_style (string): the string to use to create the outline of the boxes. + + Returns: + output (VisImage): image object with box drawn. + """ + x0, y0, x1, y1 = box_coord + width = x1 - x0 + height = y1 - y0 + + linewidth = max(self._default_font_size / 4, 1) + + self.output.ax.add_patch( + mpl.patches.Rectangle( + (x0, y0), + width, + height, + fill=False, + edgecolor=edge_color, + linewidth=linewidth * self.output.scale, + alpha=alpha, + linestyle=line_style, + ) + ) + return self.output + + def draw_rotated_box_with_label( + self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None + ): + """ + Draw a rotated box with label on its top-left corner. + + Args: + rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle), + where cnt_x and cnt_y are the center coordinates of the box. + w and h are the width and height of the box. 
angle represents how + many degrees the box is rotated CCW with regard to the 0-degree box. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + edge_color: color of the outline of the box. Refer to `matplotlib.colors` + for full list of formats that are accepted. + line_style (string): the string to use to create the outline of the boxes. + label (string): label for rotated box. It will not be rendered when set to None. + + Returns: + output (VisImage): image object with box drawn. + """ + cnt_x, cnt_y, w, h, angle = rotated_box + area = w * h + # use thinner lines when the box is small + linewidth = self._default_font_size / ( + 6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3 + ) + + theta = angle * math.pi / 180.0 + c = math.cos(theta) + s = math.sin(theta) + rect = [ + (-w / 2, h / 2), + (-w / 2, -h / 2), + (w / 2, -h / 2), + (w / 2, h / 2), + ] + # x: left->right ; y: top->down + rotated_rect = [ + (s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect + ] + for k in range(4): + j = (k + 1) % 4 + self.draw_line( + [rotated_rect[k][0], rotated_rect[j][0]], + [rotated_rect[k][1], rotated_rect[j][1]], + color=edge_color, + linestyle="--" if k == 1 else line_style, + linewidth=linewidth, + ) + + if label is not None: + text_pos = rotated_rect[1] # topleft corner + + height_ratio = h / np.sqrt(self.output.height * self.output.width) + label_color = self._change_color_brightness( + edge_color, brightness_factor=0.7 + ) + font_size = ( + np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) + * 0.5 + * self._default_font_size + ) + self.draw_text( + label, + text_pos, + color=label_color, + font_size=font_size, + rotation=angle, + ) + + return self.output + + def draw_circle(self, circle_coord, color, radius=3): + """ + Args: + circle_coord (list(int) or tuple(int)): contains the x and y coordinates + of the center of the circle. + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + radius (int): radius of the circle. + + Returns: + output (VisImage): image object with box drawn. + """ + x, y = circle_coord + self.output.ax.add_patch( + mpl.patches.Circle(circle_coord, radius=radius, fill=False, color=color) + ) + return self.output + + def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None): + """ + Args: + x_data (list[int]): a list containing x values of all the points being drawn. + Length of list should match the length of y_data. + y_data (list[int]): a list containing y values of all the points being drawn. + Length of list should match the length of x_data. + color: color of the line. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + linestyle: style of the line. Refer to `matplotlib.lines.Line2D` + for a full list of formats that are accepted. + linewidth (float or None): width of the line. When it's None, + a default value will be computed and used. + + Returns: + output (VisImage): image object with line drawn. 
+ """ + if linewidth is None: + linewidth = self._default_font_size / 3 + linewidth = max(linewidth, 1) + self.output.ax.add_line( + mpl.lines.Line2D( + x_data, + y_data, + linewidth=linewidth * self.output.scale, + color=color, + linestyle=linestyle, + ) + ) + return self.output + + def draw_binary_mask( + self, + binary_mask, + color=None, + *, + edge_color=None, + text=None, + alpha=0.5, + area_threshold=10, + ): + """ + Args: + binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and + W is the image width. Each value in the array is either a 0 or 1 value of uint8 + type. + color: color of the mask. Refer to `matplotlib.colors` for a full list of + formats that are accepted. If None, will pick a random color. + edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a + full list of formats that are accepted. + text (str): if None, will be drawn on the object + alpha (float): blending efficient. Smaller values lead to more transparent masks. + area_threshold (float): a connected component smaller than this area will not be shown. + + Returns: + output (VisImage): image object with mask drawn. + """ + if color is None: + color = random_color(rgb=True, maximum=1) + color = mplc.to_rgb(color) + + has_valid_segment = False + binary_mask = binary_mask.astype("uint8") # opencv needs uint8 + mask = GenericMask(binary_mask, self.output.height, self.output.width) + shape2d = (binary_mask.shape[0], binary_mask.shape[1]) + + if not mask.has_holes: + # draw polygons for regular masks + for segment in mask.polygons: + area = mask_util.area( + mask_util.frPyObjects([segment], shape2d[0], shape2d[1]) + ) + if area < (area_threshold or 0): + continue + has_valid_segment = True + segment = segment.reshape(-1, 2) + self.draw_polygon( + segment, color=color, edge_color=edge_color, alpha=alpha + ) + else: + # TODO: Use Path/PathPatch to draw vector graphics: + # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon + rgba = np.zeros(shape2d + (4,), dtype="float32") + rgba[:, :, :3] = color + rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha + has_valid_segment = True + self.output.ax.imshow( + rgba, extent=(0, self.output.width, self.output.height, 0) + ) + + if text is not None and has_valid_segment: + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + self._draw_text_in_mask(binary_mask, text, lighter_color) + return self.output + + def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5): + """ + Args: + soft_mask (ndarray): float array of shape (H, W), each value in [0, 1]. + color: color of the mask. Refer to `matplotlib.colors` for a full list of + formats that are accepted. If None, will pick a random color. + text (str): if None, will be drawn on the object + alpha (float): blending efficient. Smaller values lead to more transparent masks. + + Returns: + output (VisImage): image object with mask drawn. 
+ """ + if color is None: + color = random_color(rgb=True, maximum=1) + color = mplc.to_rgb(color) + + shape2d = (soft_mask.shape[0], soft_mask.shape[1]) + rgba = np.zeros(shape2d + (4,), dtype="float32") + rgba[:, :, :3] = color + rgba[:, :, 3] = soft_mask * alpha + self.output.ax.imshow( + rgba, extent=(0, self.output.width, self.output.height, 0) + ) + + if text is not None: + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + binary_mask = (soft_mask > 0.5).astype("uint8") + self._draw_text_in_mask(binary_mask, text, lighter_color) + return self.output + + def draw_polygon(self, segment, color, edge_color=None, alpha=0.5): + """ + Args: + segment: numpy array of shape Nx2, containing all the points in the polygon. + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a + full list of formats that are accepted. If not provided, a darker shade + of the polygon color will be used instead. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + + Returns: + output (VisImage): image object with polygon drawn. + """ + if edge_color is not None: + """ + # make edge color darker than the polygon color + if alpha > 0.8: + edge_color = self._change_color_brightness(color, brightness_factor=-0.7) + else: + edge_color = color + """ + edge_color = mplc.to_rgb(edge_color) + (1,) + + polygon = mpl.patches.Polygon( + segment, + fill=True, + facecolor=mplc.to_rgb(color) + (alpha,), + edgecolor=edge_color, + linewidth=max(self._default_font_size // 15 * self.output.scale, 1), + ) + self.output.ax.add_patch(polygon) + return self.output + + """ + Internal methods: + """ + + def _jitter(self, color): + """ + Randomly modifies given color to produce a slightly different color than the color given. + + Args: + color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color + picked. The values in the list are in the [0.0, 1.0] range. + + Returns: + jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the + color after being jittered. The values in the list are in the [0.0, 1.0] range. + """ + color = mplc.to_rgb(color) + vec = np.random.rand(3) + # better to do it in another color space + vec = vec / np.linalg.norm(vec) * 0.5 + res = np.clip(vec + color, 0, 1) + return tuple(res) + + def _create_grayscale_image(self, mask=None): + """ + Create a grayscale version of the original image. + The colors in masked area, if given, will be kept. + """ + img_bw = self.img.astype("f4").mean(axis=2) + img_bw = np.stack([img_bw] * 3, axis=2) + if mask is not None: + img_bw[mask] = self.img[mask] + return img_bw + + def _change_color_brightness(self, color, brightness_factor): + """ + Depending on the brightness_factor, gives a lighter or darker color i.e. a color with + less or more saturation than the original color. + + Args: + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of + 0 will correspond to no change, a factor in [-1.0, 0) range will result in + a darker color and a factor in (0, 1.0] range will result in a lighter color. + + Returns: + modified_color (tuple[double]): a tuple containing the RGB values of the + modified color. Each value in the tuple is in the [0.0, 1.0] range. 
+ """ + assert brightness_factor >= -1.0 and brightness_factor <= 1.0 + color = mplc.to_rgb(color) + polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color)) + modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1]) + modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness + modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness + modified_color = colorsys.hls_to_rgb( + polygon_color[0], modified_lightness, polygon_color[2] + ) + return modified_color + + def _convert_masks(self, masks_or_polygons): + """ + Convert different format of masks or polygons to a tuple of masks and polygons. + + Returns: + list[GenericMask]: + """ + + m = masks_or_polygons + if isinstance(m, torch.Tensor): + m = m.numpy() + ret = [] + for x in m: + if isinstance(x, GenericMask): + ret.append(x) + else: + ret.append(GenericMask(x, self.output.height, self.output.width)) + return ret + + def _draw_text_in_mask(self, binary_mask, text, color): + """ + Find proper places to draw text given a binary mask. + """ + # TODO sometimes drawn on wrong objects. the heuristics here can improve. + _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats( + binary_mask, 8 + ) + if stats[1:, -1].size == 0: + return + largest_component_id = np.argmax(stats[1:, -1]) + 1 + + # draw text on the largest component, as well as other very large components. + for cid in range(1, _num_cc): + if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH: + # median is more stable than centroid + # center = centroids[largest_component_id] + center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1] + self.draw_text(text, center, color=color) + + def get_output(self): + """ + Returns: + output (VisImage): the image output containing the visualizations added + to the image. + """ + return self.output diff --git a/Comp2Comp-main/comp2comp/visualization/dicom.py b/Comp2Comp-main/comp2comp/visualization/dicom.py new file mode 100644 index 0000000000000000000000000000000000000000..bdd8b4fa29b144719d940ef80473a7e1bc1bd64d --- /dev/null +++ b/Comp2Comp-main/comp2comp/visualization/dicom.py @@ -0,0 +1,73 @@ +import os +from pathlib import Path + +import numpy as np +import pydicom +from PIL import Image +from pydicom.dataset import Dataset, FileMetaDataset +from pydicom.uid import ExplicitVRLittleEndian + + +def to_dicom(input, output_path, plane="axial"): + """Converts a png image to a dicom image. 
Written with assistance from ChatGPT."""
+    # Accept either a path to a PNG file or an in-memory (H, W, 3) RGB array.
+    if isinstance(input, (str, Path)):
+        png_path = input
+        dicom_path = os.path.join(
+            output_path, os.path.basename(png_path).replace(".png", ".dcm")
+        )
+        image = Image.open(png_path)
+        image_array = np.array(image)
+        # Drop the alpha channel, if any; the dataset below declares 3 samples per pixel.
+        image_array = image_array[:, :, :3]
+    else:
+        image_array = input
+        dicom_path = output_path
+
+    meta = FileMetaDataset()
+    meta.MediaStorageSOPClassUID = "1.2.840.10008.5.1.4.1.1.7"
+    meta.MediaStorageSOPInstanceUID = pydicom.uid.generate_uid()
+    meta.TransferSyntaxUID = ExplicitVRLittleEndian
+    meta.ImplementationClassUID = pydicom.uid.PYDICOM_IMPLEMENTATION_UID
+
+    ds = Dataset()
+    ds.file_meta = meta
+    ds.is_little_endian = True
+    ds.is_implicit_VR = False
+    ds.SOPClassUID = "1.2.840.10008.5.1.4.1.1.7"
+    ds.SOPInstanceUID = pydicom.uid.generate_uid()
+    ds.PatientName = "John Doe"
+    ds.PatientID = "123456"
+    ds.Modality = "OT"
+    ds.SeriesInstanceUID = pydicom.uid.generate_uid()
+    ds.StudyInstanceUID = pydicom.uid.generate_uid()
+    ds.FrameOfReferenceUID = pydicom.uid.generate_uid()
+    ds.BitsAllocated = 8
+    ds.BitsStored = 8
+    ds.HighBit = 7
+    ds.PhotometricInterpretation = "RGB"
+    ds.PixelRepresentation = 0
+    ds.Rows = image_array.shape[0]
+    ds.Columns = image_array.shape[1]
+    ds.SamplesPerPixel = 3
+    ds.PlanarConfiguration = 0
+
+    if plane.lower() == "axial":
+        ds.ImageOrientationPatient = [1, 0, 0, 0, 1, 0]
+    elif plane.lower() == "sagittal":
+        ds.ImageOrientationPatient = [0, 1, 0, 0, 0, -1]
+    elif plane.lower() == "coronal":
+        ds.ImageOrientationPatient = [1, 0, 0, 0, 0, -1]
+    else:
+        raise ValueError(
+            "Invalid plane value. Must be 'axial', 'sagittal', or 'coronal'."
+        )
+
+    ds.PixelData = image_array.tobytes()
+    pydicom.dcmwrite(dicom_path, ds, write_like_original=False)
+
+
+# Example usage
+if __name__ == "__main__":
+    png_path = "../../figures/spine_example.png"
+    output_path = "./"
+    plane = "sagittal"
+    to_dicom(png_path, output_path, plane)
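+
+    # The array input path, sketched with synthetic data: any (H, W, 3) uint8
+    # RGB array should work, since to_dicom only reads the array's shape and
+    # raw bytes. The output file name here is arbitrary.
+    rgb = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
+    to_dicom(rgb, "./random_rgb.dcm", plane="axial")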
diff --git a/Comp2Comp-main/comp2comp/visualization/linear_planar_reformation.py b/Comp2Comp-main/comp2comp/visualization/linear_planar_reformation.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb444eda419d65b8f7544ad54737affc73cbe64a
--- /dev/null
+++ b/Comp2Comp-main/comp2comp/visualization/linear_planar_reformation.py
@@ -0,0 +1,96 @@
+"""
+@author: louisblankemeier
+"""
+
+import numpy as np
+
+
+def linear_planar_reformation(
+    medical_volume: np.ndarray, segmentation: np.ndarray, centroids, dimension="axial"
+):
+    if dimension == "sagittal" or dimension == "coronal":
+        centroids = sorted(centroids, key=lambda x: x[2])
+    elif dimension == "axial":
+        centroids = sorted(centroids, key=lambda x: x[0])
+
+    centroids = [(int(x[0]), int(x[1]), int(x[2])) for x in centroids]
+    sagittal_centroids = [centroids[i][0] for i in range(0, len(centroids))]
+    coronal_centroids = [centroids[i][1] for i in range(0, len(centroids))]
+    axial_centroids = [centroids[i][2] for i in range(0, len(centroids))]
+
+    sagittal_vals, coronal_vals, axial_vals = [], [], []
+
+    if dimension == "sagittal":
+        sagittal_vals = [sagittal_centroids[0]] * axial_centroids[0]
+
+    if dimension == "coronal":
+        coronal_vals = [coronal_centroids[0]] * axial_centroids[0]
+
+    if dimension == "axial":
+        axial_vals = [axial_centroids[0]] * sagittal_centroids[0]
+
+    for i in range(1, len(axial_centroids)):
+        if dimension == "sagittal" or dimension == "coronal":
+            num = axial_centroids[i] - axial_centroids[i - 1]
+        elif dimension == "axial":
+            num = sagittal_centroids[i] - sagittal_centroids[i - 1]
+
+        if dimension == "sagittal":
+            interp = list(
+                np.linspace(sagittal_centroids[i - 1], sagittal_centroids[i], num=num)
+            )
+            sagittal_vals.extend(interp)
+
+        if dimension == "coronal":
+            interp = list(
+                np.linspace(coronal_centroids[i - 1], coronal_centroids[i], num=num)
+            )
+            coronal_vals.extend(interp)
+
+        if dimension == "axial":
+            interp = list(
+                np.linspace(axial_centroids[i - 1], axial_centroids[i], num=num)
+            )
+            axial_vals.extend(interp)
+
+    if dimension == "sagittal":
+        sagittal_vals.extend(
+            [sagittal_centroids[-1]] * (medical_volume.shape[2] - len(sagittal_vals))
+        )
+        sagittal_vals = np.array(sagittal_vals)
+        sagittal_vals = sagittal_vals.astype(int)
+
+    if dimension == "coronal":
+        coronal_vals.extend(
+            [coronal_centroids[-1]] * (medical_volume.shape[2] - len(coronal_vals))
+        )
+        coronal_vals = np.array(coronal_vals)
+        coronal_vals = coronal_vals.astype(int)
+
+    if dimension == "axial":
+        axial_vals.extend(
+            [axial_centroids[-1]] * (medical_volume.shape[0] - len(axial_vals))
+        )
+        axial_vals = np.array(axial_vals)
+        axial_vals = axial_vals.astype(int)
+
+    if dimension == "sagittal":
+        sagittal_image = medical_volume[sagittal_vals, :, range(len(sagittal_vals))]
+        sagittal_label = segmentation[sagittal_vals, :, range(len(sagittal_vals))]
+
+    if dimension == "coronal":
+        coronal_image = medical_volume[:, coronal_vals, range(len(coronal_vals))]
+        coronal_label = segmentation[:, coronal_vals, range(len(coronal_vals))]
+
+    if dimension == "axial":
+        axial_image = medical_volume[range(len(axial_vals)), :, axial_vals]
+        axial_label = segmentation[range(len(axial_vals)), :, axial_vals]
+
+    if dimension == "sagittal":
+        return sagittal_image, sagittal_label
+
+    if dimension == "coronal":
+        return coronal_image, coronal_label
+
+    if dimension == "axial":
+        return axial_image, axial_label
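+
+
+if __name__ == "__main__":
+    # Minimal sketch on synthetic data (illustration only): a random volume,
+    # an empty segmentation, and three centroids ordered along the scan. The
+    # indices are interpolated between centroids, then the volume is sampled.
+    volume = np.random.rand(128, 128, 64)
+    seg = np.zeros_like(volume, dtype=np.uint8)
+    centroids = [(40, 60, 10), (64, 64, 30), (80, 70, 50)]
+    image, label = linear_planar_reformation(
+        volume, seg, centroids, dimension="sagittal"
+    )
+    print(image.shape, label.shape)  # one reformatted 2D image + label map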
diff --git a/Comp2Comp-main/docs/Local Implementation @ M1 arm64 Silicon.md b/Comp2Comp-main/docs/Local Implementation @ M1 arm64 Silicon.md
new file mode 100644
index 0000000000000000000000000000000000000000..4071ebdd5635f9a440d55df95e32b8609a272a51
--- /dev/null
+++ b/Comp2Comp-main/docs/Local Implementation @ M1 arm64 Silicon.md
@@ -0,0 +1,67 @@
+# Local Implementation @ M1/arm64/AppleSilicon
+
+Due to dependencies and differences in architecture, the direct installation of *Comp2Comp* using install.sh or setup.py did not work on a local machine with arm64 / Apple Silicon running macOS. This guide is mainly based on [issue #30](https://github.com/StanfordMIMI/Comp2Comp/issues/30). Most of the problems I encountered were caused by requiring TensorFlow and PyTorch in the same environment, which (especially for TensorFlow) can be tricky. This guide therefore focuses more on setting up the environment on arm64 / Apple Silicon than on *Comp2Comp* or *TotalSegmentator* themselves.
+
+## Installation
+Comp2Comp requires TensorFlow, and TotalSegmentator requires PyTorch. At the moment neither *Comp2Comp* nor *TotalSegmentator* can make use of the M1 GPU; even so, the arm64-specific builds are required.
+
+### TensorFlow
+For reference:
+- https://developer.apple.com/metal/tensorflow-plugin/
+- https://developer.apple.com/forums/thread/683757
+- https://developer.apple.com/forums/thread/686926?page=2
+
+1. Create an environment (Python 3.8 or 3.9) using miniforge: https://github.com/conda-forge/miniforge. (TensorFlow did not work for others using Anaconda; you may be able to get it running by using -c apple and -c conda-forge for the further steps. However, I am not sure whether the channel (and the retrieved packages) or Anaconda's Python itself is the problem.)
+
+2. Install TensorFlow and tensorflow-metal in these versions:
+```
+conda install -c apple tensorflow-deps=2.9.0 -y
+python -m pip install tensorflow-macos==2.9
+python -m pip install tensorflow-metal==0.5.0
+```
+If you use other methods to install TensorFlow, version 2.11.0 might be the best option. TensorFlow version 2.12.0 has caused some problems.
+
+### PyTorch
+For reference: https://pytorch.org. The nightly build is not needed (at least for -c conda-forge or -c pytorch); the default build already supports GPU acceleration on arm64.
+
+3. Install PyTorch:
+```
+conda install pytorch torchvision torchaudio -c pytorch
+```
+
+### Other Dependencies (NumPy and scikit-learn)
+4. Install the remaining packages:
+```
+conda install -c conda-forge numpy scikit-learn -y
+```
+
+### TotalSegmentator
+Louis et al. modified the original *TotalSegmentator* (https://github.com/wasserth/TotalSegmentator) for use with *Comp2Comp*; *Comp2Comp* does not work with the original version. With the current version of the modified *TotalSegmentator* (https://github.com/StanfordMIMI/TotalSegmentator), no adaptations are necessary.
+
+### Comp2Comp
+For *Comp2Comp* on M1, however, it is important **not** to use bin/install.sh, as some of the predefined requirements won't work. Thus:
+
+5. Clone *Comp2Comp*:
+```
+git clone https://github.com/StanfordMIMI/Comp2Comp.git
+```
+
+6. Modify setup.py by removing
+- `"numpy==1.23.5"`
+- `"tensorflow>=2.0.0"`
+
+(You installed these manually in the steps above.)
+
+7. Install *Comp2Comp* with
+```
+python -m pip install -e .
+```
+
+## Performance
+Using an M1 Max with 64 GB RAM:
+- `process 2d` (Comp2Comp on predefined slices): 250 slices in 14.2 sec / 361 slices in 17.9 sec
+- `process 3d` (segmentation of the spine and identification of slices using TotalSegmentator, then Comp2Comp on the identified slices): high-resolution full-body scan, 1367 sec
+
+## ToDos / Nice2Have / Future
+- Integrating and using `--fast` and `--body_seg` for TotalSegmentator might be preferable.
+- TotalSegmentator works only with CUDA-compatible GPUs (not `"mps"`). Whether `torch.device("mps")` will be supported in the future is unclear; see also https://github.com/wasserth/TotalSegmentator/issues/39. Currently, only the CPU is used.
diff --git a/Comp2Comp-main/docs/Makefile b/Comp2Comp-main/docs/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..d0c3cbf1020d5c292abdedf27627c6abe25e2293
--- /dev/null
+++ b/Comp2Comp-main/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS    ?=
+SPHINXBUILD   ?= sphinx-build
+SOURCEDIR     = source
+BUILDDIR      = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
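+# Any other target name (e.g. "html", "latexpdf", "clean") is forwarded
+# verbatim to sphinx-build in make mode.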
+%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/Comp2Comp-main/docs/make.bat b/Comp2Comp-main/docs/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..dc1312ab09ca6fb0267dee6b28a38e69c253631a --- /dev/null +++ b/Comp2Comp-main/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/Comp2Comp-main/docs/requirements.txt b/Comp2Comp-main/docs/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..61feca2e6b0d510be73bb1994aac4459ce43b0c1 --- /dev/null +++ b/Comp2Comp-main/docs/requirements.txt @@ -0,0 +1,6 @@ +sphinx +sphinx-rtd-theme +recommonmark +sphinx_bootstrap_theme +sphinxcontrib-bibtex>=2.0.0 +m2r2 \ No newline at end of file diff --git a/Comp2Comp-main/docs/source/conf.py b/Comp2Comp-main/docs/source/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..f374c3872a7bd29d350b6662a3eff7aa1dbceb79 --- /dev/null +++ b/Comp2Comp-main/docs/source/conf.py @@ -0,0 +1,57 @@ +# Configuration file for the Sphinx documentation builder. 
+#
+# For the full list of built-in configuration values, see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Project information -----------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
+
+project = "comp2comp"
+copyright = "2023, StanfordMIMI"
+author = "StanfordMIMI"
+
+# -- General configuration ---------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
+
+# Adapted from https://github.com/pyvoxel/pyvoxel
+
+extensions = [
+    "sphinx.ext.autodoc",
+    "sphinx.ext.autosummary",
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.todo",
+    "sphinx.ext.coverage",
+    "sphinx.ext.mathjax",
+    "sphinx.ext.ifconfig",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.githubpages",
+    "sphinx.ext.napoleon",
+    "sphinxcontrib.bibtex",
+    "sphinx_rtd_theme",
+    "m2r2",
+]
+
+autosummary_generate = True
+autosummary_imported_members = True
+
+bibtex_bibfiles = ["references.bib"]
+
+templates_path = ["_templates"]
+exclude_patterns = []
+
+
+pygments_style = "sphinx"
+html_theme = "sphinx_rtd_theme"
+htmlhelp_basename = "Comp2Compdoc"
+html_static_path = ["_static"]
+
+intersphinx_mapping = {"numpy": ("https://numpy.org/doc/stable/", None)}
+html_theme_options = {"navigation_depth": 2}
+
+source_suffix = [".rst", ".md"]
+
+todo_include_todos = True
+napoleon_use_ivar = True
+napoleon_google_docstring = True
+html_show_sourcelink = False
diff --git a/Comp2Comp-main/docs/source/index.rst b/Comp2Comp-main/docs/source/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..b3598af1ea9badb8a06c74d46cd0af1498fccb21
--- /dev/null
+++ b/Comp2Comp-main/docs/source/index.rst
@@ -0,0 +1,13 @@
+.. comp2comp documentation master file, created by
+   sphinx-quickstart on Sun Apr 9 21:28:41 2023.
+   You can adapt this file completely to your liking, but it should at least
+   contain the root `toctree` directive.
+
+Welcome to comp2comp's documentation!
+=====================================
+
+.. mdinclude:: ../../README.md
+
+.. 
toctree:: + :maxdepth: 2 + :hidden: diff --git a/Comp2Comp-main/figures/aaa_diameter_graph.png b/Comp2Comp-main/figures/aaa_diameter_graph.png new file mode 100644 index 0000000000000000000000000000000000000000..62bc91d880fc13522d40461b425f642ee4363c78 Binary files /dev/null and b/Comp2Comp-main/figures/aaa_diameter_graph.png differ diff --git a/Comp2Comp-main/figures/aaa_segmentation_video.gif b/Comp2Comp-main/figures/aaa_segmentation_video.gif new file mode 100644 index 0000000000000000000000000000000000000000..892fb6d4d04293f36f63b09446e911eacf141113 --- /dev/null +++ b/Comp2Comp-main/figures/aaa_segmentation_video.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:22075134c1b6dfad6e9bebdfa8f6ef414fd9284414cc059dfe59482137f20922 +size 9557543 diff --git a/Comp2Comp-main/figures/aortic_aneurysm_example.png b/Comp2Comp-main/figures/aortic_aneurysm_example.png new file mode 100644 index 0000000000000000000000000000000000000000..4a97baa3616a7bc885feab431aac2ed58dbee691 Binary files /dev/null and b/Comp2Comp-main/figures/aortic_aneurysm_example.png differ diff --git a/Comp2Comp-main/figures/hip_example.png b/Comp2Comp-main/figures/hip_example.png new file mode 100644 index 0000000000000000000000000000000000000000..7523de50b0783d415304b2c587b9e15b4cff1b00 Binary files /dev/null and b/Comp2Comp-main/figures/hip_example.png differ diff --git a/Comp2Comp-main/figures/liver_spleen_pancreas_example.png b/Comp2Comp-main/figures/liver_spleen_pancreas_example.png new file mode 100644 index 0000000000000000000000000000000000000000..5e53763cf4c4aafc0dcbe467445f8b71bd8161d8 --- /dev/null +++ b/Comp2Comp-main/figures/liver_spleen_pancreas_example.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f8273c031db13d59e4b60bdf841337ab551cd78652abcc1ae01fe9fb660ead9a +size 7546828 diff --git a/Comp2Comp-main/figures/muscle_adipose_tissue_example.png b/Comp2Comp-main/figures/muscle_adipose_tissue_example.png new file mode 100644 index 0000000000000000000000000000000000000000..858a326e03f3e35cb470c753ec73fd22617c83b2 Binary files /dev/null and b/Comp2Comp-main/figures/muscle_adipose_tissue_example.png differ diff --git a/Comp2Comp-main/figures/spine_example.png b/Comp2Comp-main/figures/spine_example.png new file mode 100644 index 0000000000000000000000000000000000000000..6fa29c24b13a105605b6b8f27f421629f302b03f Binary files /dev/null and b/Comp2Comp-main/figures/spine_example.png differ diff --git a/Comp2Comp-main/figures/spine_muscle_adipose_tissue_example.png b/Comp2Comp-main/figures/spine_muscle_adipose_tissue_example.png new file mode 100644 index 0000000000000000000000000000000000000000..e4672d3e69a4446195df1749c89caea3c5f0c6dc --- /dev/null +++ b/Comp2Comp-main/figures/spine_muscle_adipose_tissue_example.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e27a337261bbcbf06655bd48aba1bfb00dc4dba88d36b564db48df358afa1871 +size 1112753 diff --git a/Comp2Comp-main/logo.png b/Comp2Comp-main/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..20a44ec44bd42ad3c5f5497dffd6f9e5b8db0c53 Binary files /dev/null and b/Comp2Comp-main/logo.png differ diff --git a/Comp2Comp-main/setup.cfg b/Comp2Comp-main/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..cfe245ef3c16d6b95086c48c466f5dcf898bb21e --- /dev/null +++ b/Comp2Comp-main/setup.cfg @@ -0,0 +1,21 @@ +[isort] +multi_line_output=3 +include_trailing_comma=True +force_grid_wrap=0 +use_parentheses=True +ensure_newline_before_comments=True 
+line_length=80 + +[mypy] +python_version=3.6 +ignore_missing_imports = True +warn_unused_configs = True +disallow_untyped_defs = True +check_untyped_defs = True +warn_unused_ignores = True +warn_redundant_casts = True +show_column_numbers = True +follow_imports = silent +allow_redefinition = True +; Require all functions to be annotated +disallow_incomplete_defs = True \ No newline at end of file diff --git a/Comp2Comp-main/setup.py b/Comp2Comp-main/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..578e336522fe7200011f06d7f31f49652075ef14 --- /dev/null +++ b/Comp2Comp-main/setup.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import os +from os import path + +from setuptools import find_packages, setup + + +def get_version(): + init_py_path = path.join( + path.abspath(path.dirname(__file__)), "comp2comp", "__init__.py" + ) + init_py = open(init_py_path, "r").readlines() + version_line = [line.strip() for line in init_py if line.startswith("__version__")][ + 0 + ] + version = version_line.split("=")[-1].strip().strip("'\"") + + # The following is used to build release packages. + # Users should never use it. + suffix = os.getenv("ABCTSEG_VERSION_SUFFIX", "") + version = version + suffix + if os.getenv("BUILD_NIGHTLY", "0") == "1": + from datetime import datetime + + date_str = datetime.today().strftime("%y%m%d") + version = version + ".dev" + date_str + + new_init_py = [line for line in init_py if not line.startswith("__version__")] + new_init_py.append('__version__ = "{}"\n'.format(version)) + with open(init_py_path, "w") as f: + f.write("".join(new_init_py)) + return version + + +setup( + name="comp2comp", + version=get_version(), + author="StanfordMIMI", + url="https://github.com/StanfordMIMI/Comp2Comp", + description="Computed tomography to body composition.", + packages=find_packages(exclude=("configs", "tests")), + python_requires=">=3.9", + install_requires=[ + "pydicom", + "moviepy", + "numpy==1.23.5", + "h5py", + "tabulate", + "tqdm", + "silx", + "yacs", + "pandas", + "dosma", + "opencv-python", + "huggingface_hub", + "pycocotools", + "wget", + "tensorflow==2.12.0", + "totalsegmentator @ git+https://github.com/StanfordMIMI/TotalSegmentator.git", + "totalsegmentatorv2 @ git+https://github.com/StanfordMIMI/TotalSegmentatorV2.git", + ], + extras_require={ + "all": ["shapely", "psutil"], + "dev": [ + # Formatting + "flake8", + "isort", + "black==22.8.0", + "flake8-bugbear", + "flake8-comprehensions", + # Docs + "mock", + "sphinx", + "sphinx-rtd-theme", + "recommonmark", + "myst-parser", + ], + "contrast_phase": ["xgboost"], + }, +)
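
For reference, a minimal sketch of the release/nightly version logic that `get_version` applies above. The `compose_version` helper is hypothetical; only the two environment variables and the `.devYYMMDD` format come from the source.

```python
import os
from datetime import datetime


def compose_version(base: str) -> str:
    # Append the optional release suffix, then a .devYYMMDD tag when
    # BUILD_NIGHTLY=1, mirroring get_version() in setup.py.
    version = base + os.getenv("ABCTSEG_VERSION_SUFFIX", "")
    if os.getenv("BUILD_NIGHTLY", "0") == "1":
        version += ".dev" + datetime.today().strftime("%y%m%d")
    return version


print(compose_version("0.0.1"))  # e.g. "0.0.1", or "0.0.1.dev230415" for a nightly
```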