AdritRao committed on
Commit a3290d1
1 Parent(s): 2033578

Upload 62 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. .gitattributes +3 -0
  2. Comp2Comp-main/.github/workflows/format.yml +33 -0
  3. Comp2Comp-main/.gitignore +71 -0
  4. Comp2Comp-main/Dockerfile +5 -0
  5. Comp2Comp-main/LICENSE +201 -0
  6. Comp2Comp-main/README.md +197 -0
  7. Comp2Comp-main/bin/C2C +276 -0
  8. Comp2Comp-main/bin/C2C-slurm +46 -0
  9. Comp2Comp-main/bin/install.sh +146 -0
  10. Comp2Comp-main/comp2comp/__init__.py +8 -0
  11. Comp2Comp-main/comp2comp/aaa/aaa.py +424 -0
  12. Comp2Comp-main/comp2comp/aortic_calcium/aortic_calcium.py +408 -0
  13. Comp2Comp-main/comp2comp/aortic_calcium/aortic_calcium_visualization.py +119 -0
  14. Comp2Comp-main/comp2comp/contrast_phase/contrast_inf.py +466 -0
  15. Comp2Comp-main/comp2comp/contrast_phase/contrast_phase.py +116 -0
  16. Comp2Comp-main/comp2comp/contrast_phase/xgboost.pkl +3 -0
  17. Comp2Comp-main/comp2comp/hip/hip.py +301 -0
  18. Comp2Comp-main/comp2comp/hip/hip_utils.py +362 -0
  19. Comp2Comp-main/comp2comp/hip/hip_visualization.py +171 -0
  20. Comp2Comp-main/comp2comp/hip/tunnelvision.ipynb +73 -0
  21. Comp2Comp-main/comp2comp/inference_class_base.py +18 -0
  22. Comp2Comp-main/comp2comp/inference_pipeline.py +102 -0
  23. Comp2Comp-main/comp2comp/io/io.py +138 -0
  24. Comp2Comp-main/comp2comp/io/io_utils.py +77 -0
  25. Comp2Comp-main/comp2comp/liver_spleen_pancreas/liver_spleen_pancreas.py +95 -0
  26. Comp2Comp-main/comp2comp/liver_spleen_pancreas/liver_spleen_pancreas_visualization.py +130 -0
  27. Comp2Comp-main/comp2comp/liver_spleen_pancreas/visualization_utils.py +332 -0
  28. Comp2Comp-main/comp2comp/metrics/metrics.py +156 -0
  29. Comp2Comp-main/comp2comp/models/models.py +157 -0
  30. Comp2Comp-main/comp2comp/muscle_adipose_tissue/data.py +214 -0
  31. Comp2Comp-main/comp2comp/muscle_adipose_tissue/muscle_adipose_tissue.py +445 -0
  32. Comp2Comp-main/comp2comp/muscle_adipose_tissue/muscle_adipose_tissue_visualization.py +181 -0
  33. Comp2Comp-main/comp2comp/spine/spine.py +483 -0
  34. Comp2Comp-main/comp2comp/spine/spine_utils.py +737 -0
  35. Comp2Comp-main/comp2comp/spine/spine_visualization.py +198 -0
  36. Comp2Comp-main/comp2comp/utils/__init__.py +0 -0
  37. Comp2Comp-main/comp2comp/utils/colormap.py +156 -0
  38. Comp2Comp-main/comp2comp/utils/dl_utils.py +80 -0
  39. Comp2Comp-main/comp2comp/utils/env.py +84 -0
  40. Comp2Comp-main/comp2comp/utils/logger.py +209 -0
  41. Comp2Comp-main/comp2comp/utils/orientation.py +30 -0
  42. Comp2Comp-main/comp2comp/utils/process.py +120 -0
  43. Comp2Comp-main/comp2comp/utils/run.py +126 -0
  44. Comp2Comp-main/comp2comp/visualization/detectron_visualizer.py +1288 -0
  45. Comp2Comp-main/comp2comp/visualization/dicom.py +73 -0
  46. Comp2Comp-main/comp2comp/visualization/linear_planar_reformation.py +96 -0
  47. Comp2Comp-main/docs/Local Implementation @ M1 arm64 Silicon.md +67 -0
  48. Comp2Comp-main/docs/Makefile +20 -0
  49. Comp2Comp-main/docs/make.bat +35 -0
  50. Comp2Comp-main/docs/requirements.txt +6 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ Comp2Comp-main/figures/aaa_segmentation_video.gif filter=lfs diff=lfs merge=lfs -text
+ Comp2Comp-main/figures/liver_spleen_pancreas_example.png filter=lfs diff=lfs merge=lfs -text
+ Comp2Comp-main/figures/spine_muscle_adipose_tissue_example.png filter=lfs diff=lfs merge=lfs -text
Comp2Comp-main/.github/workflows/format.yml ADDED
@@ -0,0 +1,33 @@
+ name: Autoformat code
+
+ on:
+   push:
+     branches: [ 'main' ]
+   pull_request:
+     branches: [ 'main' ]
+
+ jobs:
+   format:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v2
+       - name: Format code
+         run: |
+           pip install black
+           black .
+       - name: Sort imports
+         run: |
+           pip install isort
+           isort .
+       - name: Remove unused imports
+         run: |
+           pip install autoflake
+           autoflake --in-place --remove-all-unused-imports --remove-unused-variables --recursive .
+       - name: Commit changes
+         uses: EndBug/add-and-commit@v4
+         with:
+           author_name: ${{ github.actor }}
+           author_email: ${{ github.actor }}@users.noreply.github.com
+           message: "Autoformat code"
+           add: "."
+           branch: ${{ github.ref }}
Comp2Comp-main/.gitignore ADDED
@@ -0,0 +1,71 @@
+ # Ignore project files
+ **/.idea
+ **/.DS_Store
+ **/.vscode
+
+ # Ignore cache
+ **/__pycache__
+
+ # Ignore egg files
+ **/*.egg-info
+
+ # Docs build files
+ docs/_build
+
+ # Ignore tensorflow logs
+ **/tf_log
+
+ # Ignore results
+ **/pik_data
+ **/preds
+
+ # Ignore test data
+ **/test_data
+ **/testing_data
+ **/sample_data
+ **/test_results
+
+ # Ignore images
+ **/model_imgs
+
+ # Ignore data visualization scripts/images
+ **/data_visualization
+ **/OAI-iMorphics
+
+ # Temp files
+ ._*
+ # Ignore checkpoint files
+ **/.ipynb_checkpoints/
+ **/.comp2comp/
+
+ # Ignore cross validation files
+ *.cv
+
+ # Ignore yml/yaml files
+ *.yml
+ *.yaml
+ !.github/workflows/format.yml
+
+ # Ignore images
+ *.png
+ !panel_example.png
+ !logo.png
+ # Except for pngs in the figures folder
+ !figures/*.png
+
+ # Ignore any weights files
+ weights/
+
+ # Preferences file
+ comp2comp/preferences.yaml
+
+ # Model directory
+ **/.comp2comp_model_dir/
+
+ # Slurm outputs
+ **/slurm/
+
+ # Ignore outputs folder
+ **/outputs/
+
+ **/models/
Comp2Comp-main/Dockerfile ADDED
@@ -0,0 +1,5 @@
+ FROM python:3.8
+ COPY . /Comp2Comp
+ WORKDIR /Comp2Comp
+ RUN pip install -e .
+ RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
Comp2Comp-main/LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
Comp2Comp-main/README.md ADDED
@@ -0,0 +1,197 @@
+ # <img src="logo.png" width="40" height="40" /> Comp2Comp
+ [![License: Apache 2.0](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
+ ![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/StanfordMIMI/Comp2Comp/format.yml?branch=master)
+ [![Documentation Status](https://readthedocs.org/projects/comp2comp/badge/?version=latest)](https://comp2comp.readthedocs.io/en/latest/?badge=latest)
+
+ [**Paper**](https://arxiv.org/abs/2302.06568)
+ | [**Installation**](#installation)
+ | [**Basic Usage**](#basic_usage)
+ | [**Inference Pipelines**](#inference_pipeline)
+ | [**Contribute**](#contribute)
+ | [**Citation**](#citation)
+
+ Comp2Comp is a library for extracting clinical insights from computed tomography scans.
+
+ ## Installation
+ <a name="installation"></a>
+ ```bash
+ git clone https://github.com/StanfordMIMI/Comp2Comp/
+
+ # The install script requires Anaconda/Miniconda.
+ cd Comp2Comp && bin/install.sh
+ ```
+
+ Alternatively, Comp2Comp can be installed with `pip`:
+ ```bash
+ git clone https://github.com/StanfordMIMI/Comp2Comp/
+ cd Comp2Comp
+ conda create -n c2c_env python=3.8
+ conda activate c2c_env
+ pip install -e .
+ ```
+
+ For installation on the Apple M1 chip, see [these instructions](https://github.com/StanfordMIMI/Comp2Comp/blob/master/docs/Local%20Implementation%20%40%20M1%20arm64%20Silicon.md).
+
+ ## Basic Usage
+ <a name="basic_usage"></a>
+ ```bash
+ bin/C2C <pipeline_name> -i <path/to/input/folder>
+ ```
+
+ For running on slurm, modify the above command as follows:
+ ```bash
+ bin/C2C-slurm <pipeline_name> -i <path/to/input/folder>
+ ```
+
+ ## Inference Pipelines
+ <a name="inference_pipeline"></a>
+ We have designed Comp2Comp to be highly extensible and to enable the development of complex, clinically relevant applications. We observed that many clinical applications require chaining several machine learning or other computational modules together to generate complex insights, and the inference pipeline system is designed to make this easy. Furthermore, we seek to make the code readable and modular, so that the community can easily contribute to the project.
+
+ The [`InferencePipeline` class](comp2comp/inference_pipeline.py) is used to create inference pipelines, which are made up of a sequence of [`InferenceClass` objects](comp2comp/inference_class_base.py). When the `InferencePipeline` object is called, it sequentially calls the `InferenceClass` objects that were provided to the constructor.
+
+ The first argument of the `__call__` function of `InferenceClass` must be the `InferencePipeline` object. This allows each `InferenceClass` object to access or set attributes of the `InferencePipeline` object that can then be accessed by subsequent `InferenceClass` objects in the pipeline. Each `InferenceClass` object should return a dictionary whose keys match the keyword arguments of the next `InferenceClass`'s `__call__` function. If an `InferenceClass` object only sets attributes of the `InferencePipeline` object and has nothing to pass on, it can return an empty dictionary.
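+
+ As an illustration of this contract, a minimal custom pipeline might look like the following sketch. The step classes and the `medical_volume` keyword here are hypothetical examples, not part of the library; only the `InferencePipeline` and `InferenceClass` interfaces described above are assumed:
+
+ ```python
+ from comp2comp.inference_class_base import InferenceClass
+ from comp2comp.inference_pipeline import InferencePipeline
+
+
+ class ComputeMeanHU(InferenceClass):
+     """Hypothetical step: computes the mean HU of the input volume."""
+
+     def __call__(self, inference_pipeline, medical_volume):
+         # Attributes set on the pipeline object are visible to all later steps.
+         inference_pipeline.mean_hu = medical_volume.mean()
+         # The returned keys become keyword arguments of the next step's __call__.
+         return {"mean_hu": inference_pipeline.mean_hu}
+
+
+ class ReportMeanHU(InferenceClass):
+     """Hypothetical step: consumes the value returned by the previous step."""
+
+     def __call__(self, inference_pipeline, mean_hu):
+         print(f"Mean HU: {mean_hu:.1f}")
+         return {}  # nothing to pass to later steps
+
+
+ pipeline = InferencePipeline([ComputeMeanHU(), ReportMeanHU()])
+ ```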
+
+ Below are the inference pipelines currently supported by Comp2Comp.
+
+ ## Spine Bone Mineral Density from 3D Trabecular Bone Regions at T12-L5
+
+ ### Usage
+ ```bash
+ bin/C2C spine -i <path/to/input/folder>
+ ```
+ - input_path should contain a DICOM series or subfolders that contain DICOM series.
+
+ ### Example Output Image
+ <p align="center">
+   <img src="figures/spine_example.png" height="300">
+ </p>
+
+ ## End-to-End Spine, Muscle, and Adipose Tissue Analysis at T12-L5
+
+ ### Usage
+ ```bash
+ bin/C2C spine_muscle_adipose_tissue -i <path/to/input/folder>
+ ```
+ - input_path should contain a DICOM series or subfolders that contain DICOM series.
+
+ ### Example Output Image
+ <p align="center">
+   <img src="figures/spine_muscle_adipose_tissue_example.png" height="300">
+ </p>
+
+ ## AAA Segmentation and Maximum Diameter Measurement
+
+ ### Usage
+ ```bash
+ bin/C2C aaa -i <path/to/input/folder>
+ ```
+ - input_path should contain a DICOM series or subfolders that contain DICOM series.
+
+ ### Example Output Image (slice with largest diameter)
+ <p align="center">
+   <img src="figures/aortic_aneurysm_example.png" height="300">
+ </p>
+
+ <div align="center">
+
+ | Example Output Video | Example Output Graph |
+ |-----------------------------|----------------------------|
+ | <p align="center"><img src="figures/aaa_segmentation_video.gif" height="300"></p> | <p align="center"><img src="figures/aaa_diameter_graph.png" height="300"></p> |
+
+ </div>
+
+ ## Contrast Phase Detection
+
+ ### Usage
+ ```bash
+ bin/C2C contrast_phase -i <path/to/input/folder>
+ ```
+ - input_path should contain a DICOM series or subfolders that contain DICOM series.
+ - This pipeline has extra dependencies. To install them, run:
+ ```bash
+ cd Comp2Comp
+ pip install -e '.[contrast_phase]'
+ ```
+
+ ## 3D Analysis of Liver, Spleen, and Pancreas
+
+ ### Usage
+ ```bash
+ bin/C2C liver_spleen_pancreas -i <path/to/input/folder>
+ ```
+ - input_path should contain a DICOM series or subfolders that contain DICOM series.
+
+ ### Example Output Image
+ <p align="center">
+   <img src="figures/liver_spleen_pancreas_example.png" height="300">
+ </p>
+
+ ## 3D Analysis of the Femur
+
+ ### Usage
+ ```bash
+ bin/C2C hip -i <path/to/input/folder>
+ ```
+ - input_path should contain a DICOM series or subfolders that contain DICOM series.
+
+ ### Example Output Image
+ <p align="center">
+   <img src="figures/hip_example.png" height="300">
+ </p>
+
+ ## Abdominal Aortic Calcification Segmentation
+
+ ### Usage
+ ```bash
+ bin/C2C aortic_calcium -i <path/to/input/folder>
+ ```
+ - input_path should contain a DICOM series or subfolders that contain DICOM series.
+
+ ### Example Output
+ ```
+ Statistics on aortic calcifications:
+ Total number: 7
+ Total volume (cm³): 0.348
+ Mean HU: 570.3+/-85.8
+ Median HU: 544.2+/-85.3
+ Max HU: 981.7+/-266.4
+ Mean volume (cm³): 0.005+/-0.059
+ Median volume (cm³): 0.022
+ Max volume (cm³): 0.184
+ Min volume (cm³): 0.005
+ ```
+
+ ## Pipeline that runs all currently implemented pipelines
+
+ ### Usage
+ ```bash
+ bin/C2C all -i <path/to/input/folder>
+ ```
+ - input_path should contain a DICOM series or subfolders that contain DICOM series.
+
+ ## Contribute
+ <a name="contribute"></a>
+ We welcome all pull requests. If you have any issues, suggestions, or feedback, please open a new issue.
+
+ ## Citation
+ <a name="citation"></a>
+ ```
+ @article{blankemeier2023comp2comp,
+   title={Comp2Comp: Open-Source Body Composition Assessment on Computed Tomography},
+   author={Blankemeier, Louis and Desai, Arjun and Chaves, Juan Manuel Zambrano and Wentland, Andrew and Yao, Sally and Reis, Eduardo and Jensen, Malte and Bahl, Bhanushree and Arora, Khushboo and Patel, Bhavik N and others},
+   journal={arXiv preprint arXiv:2302.06568},
+   year={2023}
+ }
+ ```
+
+ In addition to Comp2Comp, please consider citing TotalSegmentator:
+ ```
+ @article{wasserthal2022totalsegmentator,
+   title={TotalSegmentator: robust segmentation of 104 anatomical structures in CT images},
+   author={Wasserthal, Jakob and Meyer, Manfred and Breit, Hanns-Christian and Cyriac, Joshy and Yang, Shan and Segeroth, Martin},
+   journal={arXiv preprint arXiv:2208.05868},
+   year={2022}
+ }
+ ```
Comp2Comp-main/bin/C2C ADDED
@@ -0,0 +1,276 @@
+ #!/usr/bin/env python
+ import argparse
+ import os
+
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+ os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
+
+ from comp2comp.aaa import aaa
+ from comp2comp.aortic_calcium import (
+     aortic_calcium,
+     aortic_calcium_visualization,
+ )
+ from comp2comp.contrast_phase.contrast_phase import ContrastPhaseDetection
+ from comp2comp.hip import hip
+ from comp2comp.inference_pipeline import InferencePipeline
+ from comp2comp.io import io
+ from comp2comp.liver_spleen_pancreas import (
+     liver_spleen_pancreas,
+     liver_spleen_pancreas_visualization,
+ )
+ from comp2comp.muscle_adipose_tissue import (
+     muscle_adipose_tissue,
+     muscle_adipose_tissue_visualization,
+ )
+ from comp2comp.spine import spine
+ from comp2comp.utils import orientation
+ from comp2comp.utils.process import process_3d
+
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+
+
+ ### AAA Pipeline
+ def AAAPipelineBuilder(path, args):
+     pipeline = InferencePipeline(
+         [
+             AxialCropperPipelineBuilder(path, args),
+             aaa.AortaSegmentation(),
+             aaa.AortaDiameter(),
+             aaa.AortaMetricsSaver(),
+         ]
+     )
+     return pipeline
+
+
+ def MuscleAdiposeTissuePipelineBuilder(args):
+     pipeline = InferencePipeline(
+         [
+             muscle_adipose_tissue.MuscleAdiposeTissueSegmentation(
+                 16, args.muscle_fat_model
+             ),
+             muscle_adipose_tissue.MuscleAdiposeTissuePostProcessing(),
+             muscle_adipose_tissue.MuscleAdiposeTissueComputeMetrics(),
+             muscle_adipose_tissue_visualization.MuscleAdiposeTissueVisualizer(),
+             muscle_adipose_tissue.MuscleAdiposeTissueH5Saver(),
+             muscle_adipose_tissue.MuscleAdiposeTissueMetricsSaver(),
+         ]
+     )
+     return pipeline
+
+
+ def MuscleAdiposeTissueFullPipelineBuilder(args):
+     pipeline = InferencePipeline(
+         [io.DicomFinder(args.input_path), MuscleAdiposeTissuePipelineBuilder(args)]
+     )
+     return pipeline
+
+
+ def SpinePipelineBuilder(path, args):
+     pipeline = InferencePipeline(
+         [
+             io.DicomToNifti(path),
+             spine.SpineSegmentation(args.spine_model, save=True),
+             orientation.ToCanonical(),
+             spine.SpineComputeROIs(args.spine_model),
+             spine.SpineMetricsSaver(),
+             spine.SpineCoronalSagittalVisualizer(format="png"),
+             spine.SpineReport(format="png"),
+         ]
+     )
+     return pipeline
+
+
+ def AxialCropperPipelineBuilder(path, args):
+     pipeline = InferencePipeline(
+         [
+             io.DicomToNifti(path),
+             spine.SpineSegmentation(args.spine_model),
+             orientation.ToCanonical(),
+             spine.AxialCropper(lower_level="L5", upper_level="L1", save=True),
+         ]
+     )
+     return pipeline
+
+
+ def SpineMuscleAdiposeTissuePipelineBuilder(path, args):
+     pipeline = InferencePipeline(
+         [
+             SpinePipelineBuilder(path, args),
+             spine.SpineFindDicoms(),
+             MuscleAdiposeTissuePipelineBuilder(args),
+             spine.SpineMuscleAdiposeTissueReport(),
+         ]
+     )
+     return pipeline
+
+
+ def LiverSpleenPancreasPipelineBuilder(path, args):
+     pipeline = InferencePipeline(
+         [
+             io.DicomToNifti(path),
+             liver_spleen_pancreas.LiverSpleenPancreasSegmentation(),
+             orientation.ToCanonical(),
+             liver_spleen_pancreas_visualization.LiverSpleenPancreasVisualizer(),
+             liver_spleen_pancreas_visualization.LiverSpleenPancreasMetricsPrinter(),
+         ]
+     )
+     return pipeline
+
+
+ def AorticCalciumPipelineBuilder(path, args):
+     pipeline = InferencePipeline(
+         [
+             io.DicomToNifti(path),
+             spine.SpineSegmentation(model_name="ts_spine"),
+             orientation.ToCanonical(),
+             aortic_calcium.AortaSegmentation(),
+             orientation.ToCanonical(),
+             aortic_calcium.AorticCalciumSegmentation(),
+             aortic_calcium.AorticCalciumMetrics(),
+             aortic_calcium_visualization.AorticCalciumVisualizer(),
+             aortic_calcium_visualization.AorticCalciumPrinter(),
+         ]
+     )
+     return pipeline
+
+
+ def ContrastPhasePipelineBuilder(path, args):
+     pipeline = InferencePipeline([io.DicomToNifti(path), ContrastPhaseDetection(path)])
+     return pipeline
+
+
+ def HipPipelineBuilder(path, args):
+     pipeline = InferencePipeline(
+         [
+             io.DicomToNifti(path),
+             hip.HipSegmentation(args.hip_model),
+             orientation.ToCanonical(),
+             hip.HipComputeROIs(args.hip_model),
+             hip.HipMetricsSaver(),
+             hip.HipVisualizer(),
+         ]
+     )
+     return pipeline
+
+
+ def AllPipelineBuilder(path, args):
+     pipeline = InferencePipeline(
+         [
+             io.DicomToNifti(path),
+             SpineMuscleAdiposeTissuePipelineBuilder(path, args),
+             LiverSpleenPancreasPipelineBuilder(path, args),
+             HipPipelineBuilder(path, args),
+         ]
+     )
+     return pipeline
+
+
+ def argument_parser():
+     base_parser = argparse.ArgumentParser(add_help=False)
+     base_parser.add_argument("--input_path", "-i", type=str, required=True)
+     base_parser.add_argument("--output_path", "-o", type=str)
+     base_parser.add_argument("--save_segmentations", action="store_true")
+     base_parser.add_argument("--overwrite_outputs", action="store_true")
+
+     parser = argparse.ArgumentParser()
+     subparsers = parser.add_subparsers(dest="pipeline", help="Pipeline to run")
+
+     # Muscle + adipose tissue
+     muscle_adipose_tissue_parser = subparsers.add_parser(
+         "muscle_adipose_tissue", parents=[base_parser]
+     )
+     muscle_adipose_tissue_parser.add_argument(
+         "--muscle_fat_model", default="abCT_v0.0.1", type=str
+     )
+
+     # Spine
+     spine_parser = subparsers.add_parser("spine", parents=[base_parser])
+     spine_parser.add_argument("--spine_model", default="ts_spine", type=str)
+
+     # Spine + muscle + fat
+     spine_muscle_adipose_tissue_parser = subparsers.add_parser(
+         "spine_muscle_adipose_tissue", parents=[base_parser]
+     )
+     spine_muscle_adipose_tissue_parser.add_argument(
+         "--muscle_fat_model", default="stanford_v0.0.2", type=str
+     )
+     spine_muscle_adipose_tissue_parser.add_argument(
+         "--spine_model", default="ts_spine", type=str
+     )
+
+     # Liver + spleen + pancreas
+     liver_spleen_pancreas_parser = subparsers.add_parser(
+         "liver_spleen_pancreas", parents=[base_parser]
+     )
+
+     # Aortic calcium
+     aortic_calcium_parser = subparsers.add_parser(
+         "aortic_calcium", parents=[base_parser]
+     )
+
+     # Contrast phase
+     contrast_phase_parser = subparsers.add_parser(
+         "contrast_phase", parents=[base_parser]
+     )
+
+     # Hip
+     hip_parser = subparsers.add_parser("hip", parents=[base_parser])
+     hip_parser.add_argument(
+         "--hip_model",
+         default="ts_hip",
+         type=str,
+     )
+
+     # AAA
+     aorta_diameter_parser = subparsers.add_parser(
+         "aaa", help="aorta diameter", parents=[base_parser]
+     )
+     aorta_diameter_parser.add_argument(
+         "--aorta_model",
+         default="ts_spine",
+         type=str,
+         help="aorta model to use for inference",
+     )
+     aorta_diameter_parser.add_argument(
+         "--spine_model",
+         default="ts_spine",
+         type=str,
+         help="spine model to use for inference",
+     )
+
+     # All pipelines
+     all_parser = subparsers.add_parser("all", parents=[base_parser])
+     all_parser.add_argument(
+         "--muscle_fat_model",
+         default="abCT_v0.0.1",
+         type=str,
+     )
+     all_parser.add_argument(
+         "--spine_model",
+         default="ts_spine",
+         type=str,
+     )
+     all_parser.add_argument(
+         "--hip_model",
+         default="ts_hip",
+         type=str,
+     )
+     return parser
+
+
+ def main():
+     args = argument_parser().parse_args()
+     if args.pipeline == "spine_muscle_adipose_tissue":
+         process_3d(args, SpineMuscleAdiposeTissuePipelineBuilder)
+     elif args.pipeline == "spine":
+         process_3d(args, SpinePipelineBuilder)
+     elif args.pipeline == "contrast_phase":
+         process_3d(args, ContrastPhasePipelineBuilder)
+     elif args.pipeline == "liver_spleen_pancreas":
+         process_3d(args, LiverSpleenPancreasPipelineBuilder)
+     elif args.pipeline == "aortic_calcium":
+         process_3d(args, AorticCalciumPipelineBuilder)
+     elif args.pipeline == "hip":
+         process_3d(args, HipPipelineBuilder)
+     elif args.pipeline == "aaa":
+         process_3d(args, AAAPipelineBuilder)
+     elif args.pipeline == "all":
+         process_3d(args, AllPipelineBuilder)
+     else:
+         raise AssertionError("{} command not supported".format(args.pipeline))
+
+
+ if __name__ == "__main__":
+     main()
Comp2Comp-main/bin/C2C-slurm ADDED
@@ -0,0 +1,46 @@
+ #!/usr/bin/env python
+ import os
+ import pipes
+ import subprocess
+ import sys
+ from pathlib import Path
+
+ exec_file = sys.argv[0].split("-")[0]
+ command = exec_file + " " + " ".join([pipes.quote(s) for s in sys.argv[1:]])
+
+
+ def submit_command(command):
+     subprocess.run(command.split(" "), check=True, capture_output=False)
+
+
+ def python_submit(command, node="siena"):
+     # Write the wrapped C2C command to a temporary sbatch script.
+     with open("./slurm.sh", "w") as bash_file:
+         bash_file.write(f"#!/bin/bash\n{command}")
+     slurm_output_path = Path("./slurm/")
+     slurm_output_path.mkdir(parents=True, exist_ok=True)
+
+     try:
+         if node is None:
+             command = "sbatch --ntasks=1 --cpus-per-task=1 --output ./slurm/slurm-%j.out \
+                 --mem-per-cpu=8G -p gpu --gpus 1 --time=1:00:00 slurm.sh"
+         else:
+             command = f"sbatch --ntasks=1 --cpus-per-task=1 --output ./slurm/slurm-%j.out \
+                 --nodelist={node} --mem-per-cpu=8G -p gpu --gpus 1 --time=1:00:00 slurm.sh"
+         submit_command(command)
+         print(f'Submitted the command --- "{command}" --- to slurm.')
+     except subprocess.CalledProcessError:
+         # Fall back to an alternative resource specification if the first request fails.
+         if node is None:
+             command = "sbatch -c 8 --gres=gpu:1 --output ./slurm/slurm-%j.out --mem=128000 --time=100-00:00:00 slurm.sh"
+         else:
+             command = f"sbatch -c 8 --gres=gpu:1 --output ./slurm/slurm-%j.out --nodelist={node} --mem=128000 --time=100-00:00:00 slurm.sh"
+         submit_command(command)
+         print(f'Submitted the command --- "{command}" --- to slurm.')
+     os.remove("./slurm.sh")
+
+
+ python_submit(command)
Comp2Comp-main/bin/install.sh ADDED
@@ -0,0 +1,146 @@
+ #!/bin/bash
+
+ # ==============================================================================
+ # Auto-installation for abCTSeg for Linux and Mac machines.
+ # This setup script is adapted from DOSMA:
+ # https://github.com/ad12/DOSMA
+ # ==============================================================================
+
+ BIN_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+
+ ANACONDA_KEYWORD="anaconda"
+ ANACONDA_DOWNLOAD_URL="https://www.anaconda.com/distribution/"
+ MINICONDA_KEYWORD="miniconda"
+
+ # FIXME: Update the name.
+ ABCT_ENV_NAME="c2c_env"
+
+ hasAnaconda=0
+ updateEnv=0
+ updatePath=1
+ pythonVersion="3.9"
+ cudaVersion=""
+
+ while [[ $# -gt 0 ]]; do
+   key="$1"
+   case $key in
+     -h|--help)
+       echo "Comp2Comp environment setup"
+       echo ""
+       echo "Usage:"
+       echo "    --python <string>    Python version"
+       echo "    --cuda <string>      CUDA version"
+       echo "    -f, --force          Force environment update"
+       exit
+       ;;
+     --python)
+       pythonVersion=$2
+       shift # past argument
+       shift # past value
+       ;;
+     --cuda)
+       cudaVersion=$2
+       shift # past argument
+       shift # past value
+       ;;
+     -f|--force)
+       updateEnv=1
+       shift # past argument
+       ;;
+     *)
+       echo "Unknown option: $key"
+       exit 1
+       ;;
+   esac
+ done
+
+ # Initial setup
+ source ~/.bashrc
+ currDir=`pwd`
+
+
+ if echo $PATH | grep -q $ANACONDA_KEYWORD; then
+   hasAnaconda=1
+   echo "Conda found in path"
+ fi
+
+ if echo $PATH | grep -q $MINICONDA_KEYWORD; then
+   hasAnaconda=1
+   echo "Miniconda found in path"
+ fi
+
+ if [[ $hasAnaconda -eq 0 ]]; then
+   echo "Anaconda/Miniconda not installed - install from $ANACONDA_DOWNLOAD_URL"
+   exit 125
+ fi
+
+ # Hacky way of finding the conda base directory
+ condaPath=`which conda`
+ condaPath=`dirname ${condaPath}`
+ condaPath=`dirname ${condaPath}`
+ # Source conda
+ source $condaPath/etc/profile.d/conda.sh
+
+ # Check if OS is supported
+ if [[ "$OSTYPE" != "linux-gnu" && "$OSTYPE" != "darwin"* ]]; then
+   echo "Only Linux and MacOS are supported"
+   exit 125
+ fi
+
+ # Create the Anaconda environment (c2c_env)
+ if [[ `conda env list | grep $ABCT_ENV_NAME` ]]; then
+   if [[ ${updateEnv} -eq 0 ]]; then
+     echo "Environment '${ABCT_ENV_NAME}' is installed. Run 'conda activate ${ABCT_ENV_NAME}' to get started."
+     exit 0
+   else
+     conda env remove -n $ABCT_ENV_NAME
+     conda create -y -n $ABCT_ENV_NAME python=$pythonVersion
+   fi
+ else
+   conda create -y -n $ABCT_ENV_NAME python=$pythonVersion
+ fi
+
+ conda activate $ABCT_ENV_NAME
+
+ # Install tensorflow and keras
+ # https://www.tensorflow.org/install/source#gpu
+ # pip install tensorflow
+
+ # Install pytorch
+ # FIXME: PyTorch has to be installed with pip to respect setup.py files from nnUNet
+ # pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
+ # if [[ "$OSTYPE" == "darwin"* ]]; then
+ #     # Mac
+ #     if [[ $cudaVersion != "" ]]; then
+ #         # CPU
+ #         echo "Cannot install PyTorch with CUDA support on Mac"
+ #         exit 1
+ #     fi
+ #     conda install -y pytorch torchvision torchaudio -c pytorch
+ # else
+ #     # Linux
+ #     if [[ $cudaVersion == "" ]]; then
+ #         cudatoolkit="cpuonly"
+ #     else
+ #         cudatoolkit="cudatoolkit=${cudaVersion}"
+ #     fi
+ #     conda install -y pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 $cudatoolkit -c pytorch
+ # fi
+
+ # Install detectron2
+ # FIXME: Remove dependency on detectron2
+ # pip3 install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.10/index.html
+
+ # Install TotalSegmentator
+ # FIXME: Add this to the setup.py file
+ # pip3 install git+https://github.com/StanfordMIMI/TotalSegmentator.git
+
+ pip install -e . --no-cache-dir
+
+ echo ""
+ echo ""
+ echo "Run 'conda activate ${ABCT_ENV_NAME}' to get started."
Comp2Comp-main/comp2comp/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .utils.env import setup_environment
+
+ setup_environment()
+
+
+ # This line will be programmatically read/written by setup.py.
+ # Leave it at the bottom of this file and don't touch it.
+ __version__ = "0.0.1"
Comp2Comp-main/comp2comp/aaa/aaa.py ADDED
@@ -0,0 +1,424 @@
+ import math
+ import operator
+ import os
+ import zipfile
+ from pathlib import Path
+ from time import time
+ from tkinter import Tcl
+ from typing import Union
+
+ import cv2
+ import matplotlib.pyplot as plt
+ import moviepy.video.io.ImageSequenceClip
+ import nibabel as nib
+ import numpy as np
+ import pandas as pd
+ import pydicom
+ import wget
+ from totalsegmentator.libs import nostdout
+
+ from comp2comp.inference_class_base import InferenceClass
+
+
+ class AortaSegmentation(InferenceClass):
+     """Aorta segmentation."""
+
+     def __init__(self, save=True):
+         super().__init__()
+         self.model_name = "totalsegmentator"
+         self.save_segmentations = save
+
+     def __call__(self, inference_pipeline):
+         # inference_pipeline.dicom_series_path = self.input_path
+         self.output_dir = inference_pipeline.output_dir
+         self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
+         if not os.path.exists(self.output_dir_segmentations):
+             os.makedirs(self.output_dir_segmentations)
+
+         self.model_dir = inference_pipeline.model_dir
+
+         seg, mv = self.aorta_seg(
+             os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
+             self.output_dir_segmentations + "spine.nii.gz",
+             inference_pipeline.model_dir,
+         )
+
+         seg = seg.get_fdata()
+         medical_volume = mv.get_fdata()
+
+         axial_masks = []
+         ct_image = []
+
+         for i in range(seg.shape[2]):
+             axial_masks.append(seg[:, :, i])
+
+         for i in range(medical_volume.shape[2]):
+             ct_image.append(medical_volume[:, :, i])
+
+         # Save input axial slices to the pipeline
+         inference_pipeline.ct_image = ct_image
+
+         # Save aorta masks to the pipeline
+         inference_pipeline.axial_masks = axial_masks
+
+         return {}
+
+     def setup_nnunet_c2c(self, model_dir: Union[str, Path]):
+         """Adapted from TotalSegmentator."""
+
+         model_dir = Path(model_dir)
+         config_dir = model_dir / Path("." + self.model_name)
+         (config_dir / "nnunet/results/nnUNet/3d_fullres").mkdir(
+             exist_ok=True, parents=True
+         )
+         (config_dir / "nnunet/results/nnUNet/2d").mkdir(exist_ok=True, parents=True)
+         weights_dir = config_dir / "nnunet/results"
+         self.weights_dir = weights_dir
+
+         # These variables are not used per se; they just need to point at an
+         # existing directory for nnU-Net to run.
+         os.environ["nnUNet_raw_data_base"] = str(weights_dir)
+         os.environ["nnUNet_preprocessed"] = str(weights_dir)
+         os.environ["RESULTS_FOLDER"] = str(weights_dir)
+
+     def download_aorta_model(self, model_dir: Union[str, Path]):
+         download_dir = Path(
+             os.path.join(
+                 self.weights_dir,
+                 "nnUNet/3d_fullres/Task253_Aorta/nnUNetTrainerV2_ep4000_nomirror__nnUNetPlansv2.1",
+             )
+         )
+         print(download_dir)
+         fold_0_path = download_dir / "fold_0"
+         if not os.path.exists(fold_0_path):
+             download_dir.mkdir(parents=True, exist_ok=True)
+             wget.download(
+                 "https://huggingface.co/AdritRao/aaa_test/resolve/main/fold_0.zip",
+                 out=os.path.join(download_dir, "fold_0.zip"),
+             )
+             with zipfile.ZipFile(
+                 os.path.join(download_dir, "fold_0.zip"), "r"
+             ) as zip_ref:
+                 zip_ref.extractall(download_dir)
+             os.remove(os.path.join(download_dir, "fold_0.zip"))
+             wget.download(
+                 "https://huggingface.co/AdritRao/aaa_test/resolve/main/plans.pkl",
+                 out=os.path.join(download_dir, "plans.pkl"),
+             )
+             print("Aorta model downloaded.")
+         else:
+             print("Aorta model already downloaded.")
+
+     def aorta_seg(
+         self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir
+     ):
+         """Run aorta segmentation.
+
+         Args:
+             input_path (Union[str, Path]): Input path.
+             output_path (Union[str, Path]): Output path.
+         """
+
+         print("Segmenting aorta...")
+         st = time()
+         os.environ["SCRATCH"] = self.model_dir
+
+         print(self.model_dir)
+
+         # Setup nnU-Net
+         model = "3d_fullres"
+         folds = [0]
+         trainer = "nnUNetTrainerV2_ep4000_nomirror"
+         crop_path = None
+         task_id = [253]
+
+         self.setup_nnunet_c2c(model_dir)
+         self.download_aorta_model(model_dir)
+
+         from totalsegmentator.nnunet import nnUNet_predict_image
+
+         with nostdout():
+             img, seg = nnUNet_predict_image(
+                 input_path,
+                 output_path,
+                 task_id,
+                 model=model,
+                 folds=folds,
+                 trainer=trainer,
+                 tta=False,
+                 multilabel_image=True,
+                 resample=1.5,
+                 crop=None,
+                 crop_path=crop_path,
+                 task_name="total",
+                 nora_tag="None",
+                 preview=False,
+                 nr_threads_resampling=1,
+                 nr_threads_saving=6,
+                 quiet=False,
+                 verbose=False,
+                 test=0,
+             )
+         end = time()
+
+         # Log total time for aorta segmentation
+         print(f"Total time for aorta segmentation: {end-st:.2f}s.")
+
+         seg_data = seg.get_fdata()
+         seg = nib.Nifti1Image(seg_data, seg.affine, seg.header)
+
+         return seg, img
+
+
+ class AortaDiameter(InferenceClass):
+     def __init__(self):
+         super().__init__()
+
+     def normalize_img(self, img: np.ndarray) -> np.ndarray:
+         """Normalize the image.
+
+         Args:
+             img (np.ndarray): Input image.
+
+         Returns:
+             np.ndarray: Normalized image.
+         """
+         return (img - img.min()) / (img.max() - img.min())
+
+     def __call__(self, inference_pipeline):
+         # Both are lists of 2D numpy arrays of shape (512, 512), one per axial slice.
+         axial_masks = inference_pipeline.axial_masks
+         ct_img = inference_pipeline.ct_image
+
+         # Image output directories
+         output_dir = inference_pipeline.output_dir
+         output_dir_slices = os.path.join(output_dir, "images/slices/")
+         if not os.path.exists(output_dir_slices):
+             os.makedirs(output_dir_slices)
+
+         output_dir_summary = os.path.join(output_dir, "images/summary/")
+         if not os.path.exists(output_dir_summary):
+             os.makedirs(output_dir_summary)
+
+         DICOM_PATH = inference_pipeline.dicom_series_path
+         dicom = pydicom.dcmread(DICOM_PATH + "/" + os.listdir(DICOM_PATH)[0])
+
+         dicom.PhotometricInterpretation = "YBR_FULL"
+         pixel_conversion = dicom.PixelSpacing
+         print("Pixel conversion: " + str(pixel_conversion))
+         RATIO_PIXEL_TO_MM = pixel_conversion[0]
+
+         SLICE_COUNT = len(ct_img)
+         diameterDict = {}
+
+         for i in range(len(ct_img)):
+             mask = axial_masks[i].astype("uint8")
+             img = ct_img[i]
+
+             img = np.clip(img, -300, 1800)
+             img = self.normalize_img(img) * 255.0
+             img = img.reshape((img.shape[0], img.shape[1], 1))
+             img = np.tile(img, (1, 1, 3))
+
+             contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
+
+             if len(contours) != 0:
+                 # Keep only the largest contour.
+                 areas = [cv2.contourArea(c) for c in contours]
+                 sorted_areas = np.sort(areas)
+                 contours = contours[areas.index(sorted_areas[-1])]
+
+                 # Overlay the aorta mask on the CT slice.
+                 back = img.copy()
+                 cv2.drawContours(back, [contours], 0, (0, 255, 0), -1)
+                 alpha = 0.25
+                 img = cv2.addWeighted(img, 1 - alpha, back, alpha, 0)
+
+                 ellipse = cv2.fitEllipse(contours)
+                 (xc, yc), (d1, d2), angle = ellipse
+
+                 cv2.ellipse(img, ellipse, (0, 255, 0), 1)
+                 cv2.circle(img, (int(xc), int(yc)), 5, (0, 0, 255), -1)
+
+                 rmajor = max(d1, d2) / 2
+                 rminor = min(d1, d2) / 2
+
+                 # Draw the major axis.
+                 if angle > 90:
+                     angle = angle - 90
+                 else:
+                     angle = angle + 90
+                 xtop = xc + math.cos(math.radians(angle)) * rmajor
+                 ytop = yc + math.sin(math.radians(angle)) * rmajor
+                 xbot = xc + math.cos(math.radians(angle + 180)) * rmajor
+                 ybot = yc + math.sin(math.radians(angle + 180)) * rmajor
+                 cv2.line(
+                     img, (int(xtop), int(ytop)), (int(xbot), int(ybot)), (0, 0, 255), 3
+                 )
+
+                 # Draw the minor axis.
+                 if angle > 90:
+                     angle = angle - 90
+                 else:
+                     angle = angle + 90
+                 x1 = xc + math.cos(math.radians(angle)) * rminor
+                 y1 = yc + math.sin(math.radians(angle)) * rminor
+                 x2 = xc + math.cos(math.radians(angle + 180)) * rminor
+                 y2 = yc + math.sin(math.radians(angle + 180)) * rminor
+                 cv2.line(img, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 3)
+
+                 # The diameter is taken as the minor axis of the fitted ellipse.
+                 pixel_length = rminor * 2
+                 print("Pixel_length_minor: " + str(pixel_length))
+
+                 # PixelSpacing is mm per pixel, so areas scale by its square.
+                 area_px = cv2.contourArea(contours)
+                 area_mm = round(area_px * RATIO_PIXEL_TO_MM**2)
+                 area_cm = area_mm / 100
+
+                 diameter_mm = round(pixel_length * RATIO_PIXEL_TO_MM)
+                 diameter_cm = diameter_mm / 10
+
+                 diameterDict[SLICE_COUNT - i] = diameter_cm
+
+                 img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
+
+                 h, w, _ = img.shape
+                 lbls = [
+                     "Area (mm^2): " + str(area_mm),
+                     "Area (cm^2): " + str(area_cm),
+                     "Diameter (mm): " + str(diameter_mm),
+                     "Diameter (cm): " + str(diameter_cm),
+                     "Slice: " + str(SLICE_COUNT - i),
+                 ]
+                 font = cv2.FONT_HERSHEY_SIMPLEX
+
+                 scale = 0.03
+                 fontScale = min(w, h) / (25 / scale)
+
+                 cv2.putText(img, lbls[0], (10, 40), font, fontScale, (0, 255, 0), 2)
+                 cv2.putText(img, lbls[1], (10, 70), font, fontScale, (0, 255, 0), 2)
+                 cv2.putText(img, lbls[2], (10, 100), font, fontScale, (0, 255, 0), 2)
+                 cv2.putText(img, lbls[3], (10, 130), font, fontScale, (0, 255, 0), 2)
+                 cv2.putText(img, lbls[4], (10, 160), font, fontScale, (0, 255, 0), 2)
+
+                 cv2.imwrite(
+                     output_dir_slices + "slice" + str(SLICE_COUNT - i) + ".png", img
+                 )
+
+         # Plot the diameter of each slice.
+         plt.bar(list(diameterDict.keys()), diameterDict.values(), color="b")
+         plt.title(r"$\bf{Diameter}$" + " " + r"$\bf{Progression}$")
+         plt.xlabel("Slice Number")
+         plt.ylabel("Diameter Measurement (cm)")
+         plt.savefig(output_dir_summary + "diameter_graph.png", dpi=500)
+
+         print(diameterDict)
+         max_slice = max(diameterDict.items(), key=operator.itemgetter(1))[0]
+         print(max_slice)
+         print(diameterDict[max_slice])
+
+         inference_pipeline.max_diameter = diameterDict[max_slice]
+
+         # Build a side-by-side summary of the slice with the largest diameter.
+         img = ct_img[SLICE_COUNT - max_slice]
+         img = np.clip(img, -300, 1800)
+         img = self.normalize_img(img) * 255.0
+         img = img.reshape((img.shape[0], img.shape[1], 1))
+         img2 = np.tile(img, (1, 1, 3))
+         img2 = cv2.rotate(img2, cv2.ROTATE_90_COUNTERCLOCKWISE)
+
+         img1 = cv2.imread(output_dir_slices + "slice" + str(max_slice) + ".png")
+
+         border_size = 3
+         img1 = cv2.copyMakeBorder(
+             img1,
+             top=border_size,
+             bottom=border_size,
+             left=border_size,
+             right=border_size,
+             borderType=cv2.BORDER_CONSTANT,
+             value=[0, 244, 0],
+         )
+         img2 = cv2.copyMakeBorder(
+             img2,
+             top=border_size,
+             bottom=border_size,
+             left=border_size,
+             right=border_size,
+             borderType=cv2.BORDER_CONSTANT,
+             value=[244, 0, 0],
+         )
+
+         vis = np.concatenate((img2, img1), axis=1)
+         cv2.imwrite(output_dir_summary + "out.png", vis)
+
+         # Render all annotated slices into a video.
+         image_folder = output_dir_slices
+         fps = 20
+         image_files = [
+             os.path.join(image_folder, img)
+             for img in Tcl().call("lsort", "-dict", os.listdir(image_folder))
+             if img.endswith(".png")
+         ]
+         clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(
+             image_files, fps=fps
+         )
+         clip.write_videofile(output_dir_summary + "aaa.mp4")
+
+         return {}
+
+
+ class AortaMetricsSaver(InferenceClass):
+     """Save metrics to a CSV file."""
+
+     def __init__(self):
+         super().__init__()
+
+     def __call__(self, inference_pipeline):
+         """Save metrics to a CSV file."""
+         self.max_diameter = inference_pipeline.max_diameter
+         self.dicom_series_path = inference_pipeline.dicom_series_path
+         self.output_dir = inference_pipeline.output_dir
+         self.csv_output_dir = os.path.join(self.output_dir, "metrics")
+         if not os.path.exists(self.csv_output_dir):
+             os.makedirs(self.csv_output_dir, exist_ok=True)
+         self.save_results()
+         return {}
+
+     def save_results(self):
+         """Save results to a CSV file."""
+         _, filename = os.path.split(self.dicom_series_path)
+         data = [[filename, str(self.max_diameter)]]
+         df = pd.DataFrame(data, columns=["Filename", "Max Diameter"])
+         df.to_csv(os.path.join(self.csv_output_dir, "aorta_metrics.csv"), index=False)
Comp2Comp-main/comp2comp/aortic_calcium/aortic_calcium.py ADDED
@@ -0,0 +1,408 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on Thu Apr 20 20:36:05 2023
+
+ @author: maltejensen
+ """
+ import os
+ import time
+ from pathlib import Path
+ from typing import Union
+
+ import numpy as np
+ from scipy import ndimage
+ from totalsegmentator.libs import (
+     download_pretrained_weights,
+     nostdout,
+     setup_nnunet,
+ )
+
+ from comp2comp.inference_class_base import InferenceClass
+
+
+ class AortaSegmentation(InferenceClass):
+     """Aorta segmentation."""
+
+     def __init__(self):
+         super().__init__()
+         # self.input_path = input_path
+
+     def __call__(self, inference_pipeline):
+         # inference_pipeline.dicom_series_path = self.input_path
+         self.output_dir = inference_pipeline.output_dir
+         self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
+         inference_pipeline.output_dir_segmentations = os.path.join(
+             self.output_dir, "segmentations/"
+         )
+
+         if not os.path.exists(self.output_dir_segmentations):
+             os.makedirs(self.output_dir_segmentations)
+
+         self.model_dir = inference_pipeline.model_dir
+
+         mv, seg = self.aorta_seg(
+             os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
+             self.output_dir_segmentations + "organs.nii.gz",
+             inference_pipeline.model_dir,
+         )
+         # the medical volume is already set by the spine segmentation model
+         # the toCanonical method looks for "segmentation", so it's overridden
+         inference_pipeline.spine_segmentation = inference_pipeline.segmentation
+         inference_pipeline.segmentation = seg
+
+         return {}
+
+     def aorta_seg(
+         self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir
+     ):
+         """Run organ segmentation.
+
+         Args:
+             input_path (Union[str, Path]): Input path.
+             output_path (Union[str, Path]): Output path.
+         """
+
+         print("Segmenting aorta...")
+         st = time.time()
+         os.environ["SCRATCH"] = self.model_dir
+
+         # Setup nnunet
+         model = "3d_fullres"
+         folds = [0]
+         trainer = "nnUNetTrainerV2_ep4000_nomirror"
+         crop_path = None
+         task_id = [251]
+
+         setup_nnunet()
+         download_pretrained_weights(task_id[0])
+
+         from totalsegmentator.nnunet import nnUNet_predict_image
+
+         with nostdout():
+             # nnUNet_predict_image returns (image, segmentation)
+             mv, seg = nnUNet_predict_image(
+                 input_path,
+                 output_path,
+                 task_id,
+                 model=model,
+                 folds=folds,
+                 trainer=trainer,
+                 tta=False,
+                 multilabel_image=True,
+                 resample=1.5,
+                 crop=None,
+                 crop_path=crop_path,
+                 task_name="total",
+                 nora_tag="None",
+                 preview=False,
+                 nr_threads_resampling=1,
+                 nr_threads_saving=6,
+                 quiet=False,
+                 verbose=True,
+                 test=0,
+             )
+         end = time.time()
+
+         # Log total time for aorta segmentation
+         print(f"Total time for aorta segmentation: {end-st:.2f}s.")
+
+         return mv, seg
+
+
+ class AorticCalciumSegmentation(InferenceClass):
+     """Segmentation of aortic calcium."""
+
+     def __init__(self):
+         super().__init__()
+
+     def __call__(self, inference_pipeline):
+         ct = inference_pipeline.medical_volume.get_fdata()
+         aorta_mask = inference_pipeline.segmentation.get_fdata() == 7
+         spine_mask = inference_pipeline.spine_segmentation.get_fdata() > 0
+
+         inference_pipeline.calc_mask = self.detectCalcifications(
+             ct, aorta_mask, exclude_mask=spine_mask, remove_size=3
+         )
+
+         self.output_dir = inference_pipeline.output_dir
+         self.output_dir_images_organs = os.path.join(self.output_dir, "images/")
+         inference_pipeline.output_dir_images_organs = self.output_dir_images_organs
+
+         if not os.path.exists(self.output_dir_images_organs):
+             os.makedirs(self.output_dir_images_organs)
+
+         # np.save(os.path.join(self.output_dir_images_organs, 'ct.npy'), ct)
+         # np.save(os.path.join(self.output_dir_images_organs, "aorta_mask.npy"), aorta_mask)
+         # np.save(os.path.join(self.output_dir_images_organs, "spine_mask.npy"), spine_mask)
+
+         # np.save(
+         #     os.path.join(self.output_dir_images_organs, "calcium_mask.npy"),
+         #     inference_pipeline.calc_mask,
+         # )
+         # np.save(
+         #     os.path.join(self.output_dir_images_organs, "ct_scan.npy"),
+         #     inference_pipeline.medical_volume.get_fdata(),
+         # )
+
+         return {}
+
+     def detectCalcifications(
+         self,
+         ct,
+         aorta_mask,
+         exclude_mask=None,
+         return_dilated_mask=False,
+         dilation=(3, 1),
+         dilation_iteration=4,
+         return_dilated_exclude=False,
+         dilation_exclude_mask=(3, 1),
+         dilation_iteration_exclude=3,
+         show_time=False,
+         num_std=3,
+         remove_size=None,
+         verbose=False,
+         exclude_center_aorta=True,
+         return_eroded_aorta=False,
+         aorta_erode_iteration=6,
+     ):
+         """
+         Takes a CT image and an aorta segmentation (and, optionally, volumes to
+         use for excluding other structures) and returns a mask of the segmented
+         calcifications (and, optionally, other volumes). The calcium threshold is
+         adaptive: it uses the median of the CT values inside the aorta together
+         with the point one standard deviation below the median, as this is more
+         stable. num_std is multiplied by the distance between the median and the
+         one-standard-deviation mark, and can be used to control the threshold.
+
+         Args:
+             ct (array): CT image.
+             aorta_mask (array): Mask of the aorta.
+             exclude_mask (array, optional):
+                 Mask of structures to exclude, e.g. the spine. Defaults to None.
+             return_dilated_mask (bool, optional):
+                 Return the dilated aorta mask. Defaults to False.
+             dilation (list, optional):
+                 Structuring element for aorta dilation. Defaults to (3, 1).
+             dilation_iteration (int, optional):
+                 Number of iterations for the structuring element. Defaults to 4.
+             return_dilated_exclude (bool, optional):
+                 Return the dilated exclusion mask. Defaults to False.
+             dilation_exclude_mask (list, optional):
+                 Structuring element for the exclusion mask. Defaults to (3, 1).
+             dilation_iteration_exclude (int, optional):
+                 Number of iterations for the structuring element. Defaults to 3.
+             show_time (bool, optional):
+                 Show the time taken by each operation. Defaults to False.
+             num_std (float, optional):
+                 How many standard deviations out the threshold is set at. Defaults to 3.
+             remove_size (int, optional):
+                 Remove foci under a certain size. Warning: quite slow. Defaults to None.
+             verbose (bool, optional):
+                 Give verbose feedback on operations. Defaults to False.
+             exclude_center_aorta (bool, optional):
+                 Use the eroded aorta to exclude the center of the aorta. Defaults to True.
+             return_eroded_aorta (bool, optional):
+                 Return the eroded center aorta. Defaults to False.
+             aorta_erode_iteration (int, optional):
+                 Number of iterations for the structuring element. Defaults to 6.
+
+         Returns:
+             results: the mask alone, or a dict if other volumes are also returned.
+         """
+
+         def slicedDilationOrErosion(input_mask, struct, num_iteration, operation):
+             """
+             Perform the dilation or erosion on the smallest slab that will fit
+             the segmentation.
+             """
+             margin = 2 if num_iteration is None else num_iteration + 1
+
+             x_idx = np.where(input_mask.sum(axis=(1, 2)))[0]
+             x_start, x_end = x_idx[0] - margin, x_idx[-1] + margin
+             y_idx = np.where(input_mask.sum(axis=(0, 2)))[0]
+             y_start, y_end = y_idx[0] - margin, y_idx[-1] + margin
+
+             if operation == "dilate":
+                 mask_slice = ndimage.binary_dilation(
+                     input_mask[x_start:x_end, y_start:y_end, :], structure=struct
+                 ).astype(np.int8)
+             elif operation == "erode":
+                 mask_slice = ndimage.binary_erosion(
+                     input_mask[x_start:x_end, y_start:y_end, :], structure=struct
+                 ).astype(np.int8)
+
+             output_mask = input_mask.copy()
+             output_mask[x_start:x_end, y_start:y_end, :] = mask_slice
+
+             return output_mask
+
+         # remove parts that are not the abdominal aorta
+         labelled_aorta, num_classes = ndimage.label(aorta_mask)
+         if num_classes > 1:
+             if verbose:
+                 print("Removing {} parts".format(num_classes - 1))
+
+             aorta_vols = []
+
+             for i in range(1, num_classes + 1):
+                 aorta_vols.append((labelled_aorta == i).sum())
+
+             biggest_idx = np.argmax(aorta_vols) + 1
+             aorta_mask[labelled_aorta != biggest_idx] = 0
+
+         # Get the aortic CT values to set the adaptive threshold
+         aorta_ct_points = ct[aorta_mask == 1]
+
+         # the 15.8th percentile is one standard deviation below the median
+         quant = 0.158
+         quantile_median_dist = np.median(aorta_ct_points) - np.quantile(
+             aorta_ct_points, q=quant
+         )
+         calc_thres = np.median(aorta_ct_points) + quantile_median_dist * num_std
+
+         t0 = time.time()
+
+         if dilation is not None:
+             struct = ndimage.generate_binary_structure(*dilation)
+             if dilation_iteration is not None:
+                 struct = ndimage.iterate_structure(struct, dilation_iteration)
+             aorta_dilated = slicedDilationOrErosion(
+                 aorta_mask,
+                 struct=struct,
+                 num_iteration=dilation_iteration,
+                 operation="dilate",
+             )
+
+             if show_time:
+                 print("dilation mask time: {:.2f}".format(time.time() - t0))
+
+         t0 = time.time()
+         calc_mask = np.logical_and(aorta_dilated == 1, ct >= calc_thres)
+         if show_time:
+             print("find calc time: {:.2f}".format(time.time() - t0))
+
+         if exclude_center_aorta:
+             t0 = time.time()
+
+             struct = ndimage.generate_binary_structure(3, 1)
+             struct = ndimage.iterate_structure(struct, aorta_erode_iteration)
+
+             aorta_eroded = slicedDilationOrErosion(
+                 aorta_mask,
+                 struct=struct,
+                 num_iteration=aorta_erode_iteration,
+                 operation="erode",
+             )
+             calc_mask = calc_mask * (aorta_eroded == 0)
+             if show_time:
+                 print("exclude center aorta time: {:.2f} sec".format(time.time() - t0))
+
+         t0 = time.time()
+         if exclude_mask is not None:
+             if dilation_exclude_mask is not None:
+                 struct_exclude = ndimage.generate_binary_structure(
+                     *dilation_exclude_mask
+                 )
+                 if dilation_iteration_exclude is not None:
+                     struct_exclude = ndimage.iterate_structure(
+                         struct_exclude, dilation_iteration_exclude
+                     )
+
+                 exclude_mask = slicedDilationOrErosion(
+                     exclude_mask,
+                     struct=struct_exclude,
+                     num_iteration=dilation_iteration_exclude,
+                     operation="dilate",
+                 )
+
+             if show_time:
+                 print("exclude dilation time: {:.2f}".format(time.time() - t0))
+
+             t0 = time.time()
+             calc_mask = calc_mask * (exclude_mask == 0)
+             if show_time:
+                 print("exclude time: {:.2f}".format(time.time() - t0))
+
+         if remove_size is not None:
+             t0 = time.time()
+
+             labels, num_features = ndimage.label(calc_mask)
+
+             counter = 0
+             for n in range(1, num_features + 1):
+                 idx_tmp = labels == n
+                 if idx_tmp.sum() <= remove_size:
+                     calc_mask[idx_tmp] = 0
+                     counter += 1
+
+             if show_time:
+                 print("Size exclusion time: {:.1f} sec".format(time.time() - t0))
+             if verbose:
+                 print("Excluded {} foci under {}".format(counter, remove_size))
+
+         # return a dict if any of the optional volumes were requested
+         if not any([return_dilated_mask, return_dilated_exclude, return_eroded_aorta]):
+             return calc_mask.astype(np.int8)
+         else:
+             results = {}
+             results["calc_mask"] = calc_mask.astype(np.int8)
+             if return_dilated_mask:
+                 results["dilated_mask"] = aorta_dilated
+             if return_dilated_exclude:
+                 results["dilated_exclude"] = exclude_mask
+             if return_eroded_aorta:
+                 results["aorta_eroded"] = aorta_eroded
+
+             results["threshold"] = calc_thres
+
+             return results
+
+
+ class AorticCalciumMetrics(InferenceClass):
+     """Calculate metrics for the aortic calcifications."""
+
+     def __init__(self):
+         super().__init__()
+
+     def __call__(self, inference_pipeline):
+         calc_mask = inference_pipeline.calc_mask
+
+         inference_pipeline.pix_dims = inference_pipeline.medical_volume.header[
+             "pixdim"
+         ][1:4]
+         # divided by 10 to convert mm to cm
+         inference_pipeline.vol_per_pixel = np.prod(inference_pipeline.pix_dims / 10)
+
+         # count statistics for individual calcifications
+         labelled_calc, num_lesions = ndimage.label(calc_mask)
+
+         metrics = {
+             "volume": [],
+             "mean_hu": [],
+             "median_hu": [],
+             "max_hu": [],
+         }
+
+         ct = inference_pipeline.medical_volume.get_fdata()
+
+         for i in range(1, num_lesions + 1):
+             tmp_mask = labelled_calc == i
+
+             tmp_ct_vals = ct[tmp_mask]
+
+             metrics["volume"].append(
+                 len(tmp_ct_vals) * inference_pipeline.vol_per_pixel
+             )
+             metrics["mean_hu"].append(np.mean(tmp_ct_vals))
+             metrics["median_hu"].append(np.median(tmp_ct_vals))
+             metrics["max_hu"].append(np.max(tmp_ct_vals))
+
+         # Total volume of calcifications
+         calc_vol = np.sum(metrics["volume"])
+         metrics["volume_total"] = calc_vol
+
+         metrics["num_calc"] = len(metrics["volume"])
+
+         inference_pipeline.metrics = metrics
+
+         return {}
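The adaptive threshold in `detectCalcifications` is worth making concrete: with `quant = 0.158`, the distance between the median and the 15.8th percentile approximates one standard deviation of the lower half of the aortic HU distribution, and the threshold sits `num_std` of those distances above the median. A minimal sketch on synthetic data (the HU distribution is made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
# hypothetical blood-pool HU values sampled inside the aorta
aorta_ct_points = rng.normal(loc=120, scale=30, size=10_000)

num_std = 3
quant = 0.158  # ~one standard deviation below the median for a Gaussian
quantile_median_dist = np.median(aorta_ct_points) - np.quantile(
    aorta_ct_points, q=quant
)
calc_thres = np.median(aorta_ct_points) + quantile_median_dist * num_std

print(f"median = {np.median(aorta_ct_points):.0f} HU, threshold = {calc_thres:.0f} HU")
# roughly 120 + 3 * 30 = 210 HU for this synthetic distribution
```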
Comp2Comp-main/comp2comp/aortic_calcium/aortic_calcium_visualization.py ADDED
@@ -0,0 +1,119 @@
+ import os
+
+ import numpy as np
+
+ from comp2comp.inference_class_base import InferenceClass
+
+
+ class AorticCalciumVisualizer(InferenceClass):
+     def __init__(self):
+         super().__init__()
+
+     def __call__(self, inference_pipeline):
+         self.output_dir = inference_pipeline.output_dir
+         self.output_dir_images_organs = os.path.join(self.output_dir, "images/")
+         inference_pipeline.output_dir_images_organs = self.output_dir_images_organs
+
+         if not os.path.exists(self.output_dir_images_organs):
+             os.makedirs(self.output_dir_images_organs)
+
+         return {}
+
+
+ class AorticCalciumPrinter(InferenceClass):
+     def __init__(self):
+         super().__init__()
+
+     def __call__(self, inference_pipeline):
+         metrics = inference_pipeline.metrics
+
+         inference_pipeline.csv_output_dir = os.path.join(
+             inference_pipeline.output_dir, "metrics"
+         )
+         os.makedirs(inference_pipeline.csv_output_dir, exist_ok=True)
+
+         with open(
+             os.path.join(inference_pipeline.csv_output_dir, "aortic_calcification.csv"),
+             "w",
+         ) as f:
+             f.write("Volume (cm^3),Mean HU,Median HU,Max HU\n")
+             # renamed loop variables to avoid shadowing the built-in max()
+             for vol, mean_hu, median_hu, max_hu in zip(
+                 metrics["volume"],
+                 metrics["mean_hu"],
+                 metrics["median_hu"],
+                 metrics["max_hu"],
+             ):
+                 f.write(
+                     "{},{:.1f},{:.1f},{:.1f}\n".format(vol, mean_hu, median_hu, max_hu)
+                 )
+
+         with open(
+             os.path.join(
+                 inference_pipeline.csv_output_dir, "aortic_calcification_total.csv"
+             ),
+             "w",
+         ) as f:
+             f.write("Total number,{}\n".format(metrics["num_calc"]))
+             f.write("Total volume (cm^3),{}\n".format(metrics["volume_total"]))
+
+         distance = 25
+         print("\n")
+         if metrics["num_calc"] == 0:
+             print("No aortic calcifications were found.")
+         else:
+             print("Statistics on aortic calcifications:")
+             print("{:<{}}{}".format("Total number:", distance, metrics["num_calc"]))
+             print(
+                 "{:<{}}{:.3f}".format(
+                     "Total volume (cm³):", distance, metrics["volume_total"]
+                 )
+             )
+             print(
+                 "{:<{}}{:.1f}+/-{:.1f}".format(
+                     "Mean HU:",
+                     distance,
+                     np.mean(metrics["mean_hu"]),
+                     np.std(metrics["mean_hu"]),
+                 )
+             )
+             print(
+                 "{:<{}}{:.1f}+/-{:.1f}".format(
+                     "Median HU:",
+                     distance,
+                     np.mean(metrics["median_hu"]),
+                     np.std(metrics["median_hu"]),
+                 )
+             )
+             print(
+                 "{:<{}}{:.1f}+/-{:.1f}".format(
+                     "Max HU:",
+                     distance,
+                     np.mean(metrics["max_hu"]),
+                     np.std(metrics["max_hu"]),
+                 )
+             )
+             print(
+                 "{:<{}}{:.3f}+/-{:.3f}".format(
+                     "Mean volume (cm³):",
+                     distance,
+                     np.mean(metrics["volume"]),
+                     np.std(metrics["volume"]),
+                 )
+             )
+             print(
+                 "{:<{}}{:.3f}".format(
+                     "Median volume (cm³):", distance, np.median(metrics["volume"])
+                 )
+             )
+             print(
+                 "{:<{}}{:.3f}".format(
+                     "Max volume (cm³):", distance, np.max(metrics["volume"])
+                 )
+             )
+             print(
+                 "{:<{}}{:.3f}".format(
+                     "Min volume (cm³):", distance, np.min(metrics["volume"])
+                 )
+             )
+
+         print("\n")
+
+         return {}
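If the per-lesion CSV written by `AorticCalciumPrinter` needs to be consumed downstream, the column names match the header row written above. A small sketch (the run directory is hypothetical and depends on the pipeline's output path):

```python
import pandas as pd

# hypothetical run directory; the pipeline writes metrics/aortic_calcification.csv
df = pd.read_csv("outputs/example_run/metrics/aortic_calcification.csv")
print(len(df), "lesions,", df["Volume (cm^3)"].sum(), "cm^3 total")
print(df.sort_values("Max HU", ascending=False).head())
```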
Comp2Comp-main/comp2comp/contrast_phase/contrast_inf.py ADDED
@@ -0,0 +1,466 @@
+ import argparse
+ import os
+ import pickle
+ import sys
+
+ import nibabel as nib
+ import numpy as np
+ import scipy.spatial
+ import SimpleITK as sitk
+ from scipy import ndimage as ndi
+
+
+ def loadNiiToArray(path):
+     NiImg = nib.load(path)
+     array = np.array(NiImg.dataobj)
+     return array
+
+
+ def loadNiiWithSitk(path):
+     reader = sitk.ImageFileReader()
+     reader.SetImageIO("NiftiImageIO")
+     reader.SetFileName(path)
+     image = reader.Execute()
+     array = sitk.GetArrayFromImage(image)
+     return array
+
+
+ def loadNiiImageWithSitk(path):
+     reader = sitk.ImageFileReader()
+     reader.SetImageIO("NiftiImageIO")
+     reader.SetFileName(path)
+     image = reader.Execute()
+     # flip the image to be compatible with Nibabel
+     image = sitk.Flip(image, [False, True, False])
+     return image
+
+
+ def keep_masked_values(arr, mask):
+     # Get the indices of the non-zero elements in the mask
+     mask_indices = np.nonzero(mask)
+     # Use the indices to select the corresponding elements from the array
+     masked_values = arr[mask_indices]
+     # Return the selected elements as a new array
+     return masked_values
+
+
+ def get_stats(arr):
+     # # Get the indices of the non-zero elements in the array
+     # nonzero_indices = np.nonzero(arr)
+     # # Use the indices to get the non-zero elements of the array
+     # nonzero_elements = arr[nonzero_indices]
+
+     nonzero_elements = arr
+
+     # Calculate the stats for the non-zero elements
+     max_val = np.max(nonzero_elements)
+     min_val = np.min(nonzero_elements)
+     mean_val = np.mean(nonzero_elements)
+     median_val = np.median(nonzero_elements)
+     std_val = np.std(nonzero_elements)
+     variance_val = np.var(nonzero_elements)
+     return max_val, min_val, mean_val, median_val, std_val, variance_val
+
+
+ def getMaskAnteriorAtrium(mask):
+     # Mark everything anterior to the atrium on each axial slice
+     erasePreAtriumMask = mask.copy()
+     for sliceNum in range(mask.shape[-1]):
+         mask2D = mask[:, :, sliceNum]
+         itemindex = np.where(mask2D == 1)
+         if itemindex[0].size > 0:
+             row = itemindex[0][0]
+             erasePreAtriumMask[:, :, sliceNum][:row, :] = 1
+     return erasePreAtriumMask
+
+
+ """
+ Function from
+ https://stackoverflow.com/questions/46310603/how-to-compute-convex-hull-image-volume-in-3d-numpy-arrays/46314485#46314485
+ """
+
+
+ def fill_hull(image):
+     points = np.transpose(np.where(image))
+     hull = scipy.spatial.ConvexHull(points)
+     deln = scipy.spatial.Delaunay(points[hull.vertices])
+     idx = np.stack(np.indices(image.shape), axis=-1)
+     out_idx = np.nonzero(deln.find_simplex(idx) + 1)
+     out_img = np.zeros(image.shape)
+     out_img[out_idx] = 1
+     return out_img
+
+
+ def getClassBinaryMask(TSOutArray, classNum):
+     binaryMask = np.zeros(TSOutArray.shape)
+     binaryMask[TSOutArray == classNum] = 1
+     return binaryMask
+
+
+ def loadNiftis(TSNiftiPath, imageNiftiPath):
+     TSArray = loadNiiToArray(TSNiftiPath)
+     scanArray = loadNiiToArray(imageNiftiPath)
+     return TSArray, scanArray
+
+
+ # function to select one slice from a 3D SimpleITK volume
+ def selectSlice(scanImage, zslice):
+     size = list(scanImage.GetSize())
+     size[2] = 0
+     index = [0, 0, zslice]
+
+     Extractor = sitk.ExtractImageFilter()
+     Extractor.SetSize(size)
+     Extractor.SetIndex(index)
+
+     sliceImage = Extractor.Execute(scanImage)
+     return sliceImage
+
+
+ # function to apply windowing
+ def windowing(sliceImage, center=400, width=400):
+     windowMinimum = center - (width / 2)
+     windowMaximum = center + (width / 2)
+     # note: the minus sign below widens the display window to
+     # [-(center - width / 2), center + width / 2]
+     img_255 = sitk.Cast(
+         sitk.IntensityWindowing(
+             sliceImage,
+             windowMinimum=-windowMinimum,
+             windowMaximum=windowMaximum,
+             outputMinimum=0.0,
+             outputMaximum=255.0,
+         ),
+         sitk.sitkUInt8,
+     )
+     return img_255
+
+
+ def selectSampleSlice(kidneyLMask, adRMask, scanImage):
+     # Get the middle slice of the kidney mask, halfway between the first and
+     # last slices that contain mask voxels
+     middleSlice = np.where(kidneyLMask.sum(axis=(0, 1)) > 0)[0][0] + int(
+         (
+             np.where(kidneyLMask.sum(axis=(0, 1)) > 0)[0][-1]
+             - np.where(kidneyLMask.sum(axis=(0, 1)) > 0)[0][0]
+         )
+         / 2
+     )
+     middleSlice = int(middleSlice)
+     # select one slice using SimpleITK
+     sliceImageK = selectSlice(scanImage, middleSlice)
+
+     # Get the middle slice of the adrenal mask, halfway between the first and
+     # last slices that contain mask voxels
+     middleSlice = np.where(adRMask.sum(axis=(0, 1)) > 0)[0][0] + int(
+         (
+             np.where(adRMask.sum(axis=(0, 1)) > 0)[0][-1]
+             - np.where(adRMask.sum(axis=(0, 1)) > 0)[0][0]
+         )
+         / 2
+     )
+     middleSlice = int(middleSlice)
+     sliceImageA = selectSlice(scanImage, middleSlice)
+
+     sliceImageK = windowing(sliceImageK)
+     sliceImageA = windowing(sliceImageA)
+
+     return sliceImageK, sliceImageA
+
+
+ def getFeatures(TSArray, scanArray):
+     aortaMask = getClassBinaryMask(TSArray, 7)
+     IVCMask = getClassBinaryMask(TSArray, 8)
+     portalMask = getClassBinaryMask(TSArray, 9)
+     atriumMask = getClassBinaryMask(TSArray, 45)
+     kidneyLMask = getClassBinaryMask(TSArray, 3)
+     kidneyRMask = getClassBinaryMask(TSArray, 2)
+     adRMask = getClassBinaryMask(TSArray, 11)
+
+     # Remove the thoracic aorta and IVC from the aorta and IVC masks
+     anteriorAtriumMask = getMaskAnteriorAtrium(atriumMask)
+     aortaMask = aortaMask * (anteriorAtriumMask == 0)
+     IVCMask = IVCMask * (anteriorAtriumMask == 0)
+
+     # Erode the vessels to keep only their centers
+     struct2 = np.ones((3, 3, 3))
+     aortaMaskEroded = ndi.binary_erosion(aortaMask, structure=struct2).astype(
+         aortaMask.dtype
+     )
+     IVCMaskEroded = ndi.binary_erosion(IVCMask, structure=struct2).astype(IVCMask.dtype)
+
+     struct3 = np.ones((1, 1, 1))
+     portalMaskEroded = ndi.binary_erosion(portalMask, structure=struct3).astype(
+         portalMask.dtype
+     )
+     # If portalMaskEroded has fewer than 500 voxels, use the original portalMask
+     if np.count_nonzero(portalMaskEroded) < 500:
+         portalMaskEroded = portalMask
+
+     # Get the masked values from the scan
+     aortaArray = keep_masked_values(scanArray, aortaMaskEroded)
+     IVCArray = keep_masked_values(scanArray, IVCMaskEroded)
+     portalArray = keep_masked_values(scanArray, portalMaskEroded)
+     kidneyLArray = keep_masked_values(scanArray, kidneyLMask)
+     kidneyRArray = keep_masked_values(scanArray, kidneyRMask)
+
+     """Put this in a separate function and return only the pelvis arrays"""
+     # process the renal pelvis masks from the kidney masks:
+     # create the convex hull of the left kidney
+     kidneyLHull = fill_hull(kidneyLMask)
+     # exclude the left kidney mask from the left convex hull
+     kidneyLHull = kidneyLHull * (kidneyLMask == 0)
+     # erode the hull to remove the edges
+     struct = np.ones((3, 3, 3))
+     kidneyLHull = ndi.binary_erosion(kidneyLHull, structure=struct).astype(
+         kidneyLHull.dtype
+     )
+     # keep the values of the scanArray that are in the left convex hull
+     pelvisLArray = keep_masked_values(scanArray, kidneyLHull)
+
+     # create the convex hull of the right kidney
+     kidneyRHull = fill_hull(kidneyRMask)
+     # exclude the right kidney mask from the right convex hull
+     kidneyRHull = kidneyRHull * (kidneyRMask == 0)
+     # erode the hull to remove the edges
+     struct = np.ones((3, 3, 3))
+     kidneyRHull = ndi.binary_erosion(kidneyRHull, structure=struct).astype(
+         kidneyRHull.dtype
+     )
+     # keep the values of the scanArray that are in the right convex hull
+     pelvisRArray = keep_masked_values(scanArray, kidneyRHull)
+
+     # Compute (max, min, mean, median, std, variance) for each region
+     aorta_stats = get_stats(aortaArray)
+     IVC_stats = get_stats(IVCArray)
+     portal_stats = get_stats(portalArray)
+     kidneyL_stats = get_stats(kidneyLArray)
+     kidneyR_stats = get_stats(kidneyRArray)
+     pelvisL_stats = get_stats(pelvisLArray)
+     pelvisR_stats = get_stats(pelvisRArray)
+
+     # difference features for the classifier:
+     # aorta - portal and aorta - IVC, each for max, min and mean
+     aorta_porta_max = aorta_stats[0] - portal_stats[0]
+     aorta_porta_min = aorta_stats[1] - portal_stats[1]
+     aorta_porta_mean = aorta_stats[2] - portal_stats[2]
+
+     aorta_IVC_max = aorta_stats[0] - IVC_stats[0]
+     aorta_IVC_min = aorta_stats[1] - IVC_stats[1]
+     aorta_IVC_mean = aorta_stats[2] - IVC_stats[2]
+
+     # Assemble the feature vector in a fixed order: the seven regional stat
+     # tuples, followed by the six difference features
+     stats = []
+     for region_stats in [
+         aorta_stats,
+         IVC_stats,
+         portal_stats,
+         kidneyL_stats,
+         kidneyR_stats,
+         pelvisL_stats,
+         pelvisR_stats,
+     ]:
+         stats.extend(region_stats)
+
+     stats.extend(
+         [
+             aorta_porta_max,
+             aorta_porta_min,
+             aorta_porta_mean,
+             aorta_IVC_max,
+             aorta_IVC_min,
+             aorta_IVC_mean,
+         ]
+     )
+
+     return stats, kidneyLMask, adRMask
+
+
+ def loadModel():
+     c2cPath = os.path.dirname(sys.path[0])
+     filename = os.path.join(c2cPath, "comp2comp", "contrast_phase", "xgboost.pkl")
+     model = pickle.load(open(filename, "rb"))
+
+     return model
+
+
+ def predict_phase(TS_path, scan_path, outputPath=None, save_sample=False):
+     # save_sample is accepted for CLI compatibility; sample slices are
+     # always written below
+     TS_array, image_array = loadNiftis(TS_path, scan_path)
+     model = loadModel()
+     featureArray, kidneyLMask, adRMask = getFeatures(TS_array, image_array)
+     y_pred = model.predict([featureArray])
+
+     if y_pred == 0:
+         pred_phase = "non-contrast"
+     elif y_pred == 1:
+         pred_phase = "arterial"
+     elif y_pred == 2:
+         pred_phase = "venous"
+     elif y_pred == 3:
+         pred_phase = "delayed"
+
+     output_path_metrics = os.path.join(outputPath, "metrics")
+     if not os.path.exists(output_path_metrics):
+         os.makedirs(output_path_metrics)
+     outputTxt = os.path.join(output_path_metrics, "phase_prediction.txt")
+     with open(outputTxt, "w") as text_file:
+         text_file.write(pred_phase)
+     print(pred_phase)
+
+     output_path_images = os.path.join(outputPath, "images")
+     if not os.path.exists(output_path_images):
+         os.makedirs(output_path_images)
+     scanImage = loadNiiImageWithSitk(scan_path)
+     sliceImageK, sliceImageA = selectSampleSlice(kidneyLMask, adRMask, scanImage)
+     outJpgK = os.path.join(output_path_images, "sampleSliceKidney.png")
+     sitk.WriteImage(sliceImageK, outJpgK)
+     outJpgA = os.path.join(output_path_images, "sampleSliceAdrenal.png")
+     sitk.WriteImage(sliceImageA, outJpgA)
+
+
+ if __name__ == "__main__":
+     # parse optional arguments
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--TS_path", type=str, required=True, help="Input image")
+     parser.add_argument("--scan_path", type=str, required=True, help="Input image")
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         required=False,
+         help="Output .txt prediction",
+         default=None,
+     )
+     parser.add_argument(
+         "--save_sample",
+         type=bool,
+         required=False,
+         help="Save jpeg sample",
+         default=False,
+     )
+     args = parser.parse_args()
+     predict_phase(args.TS_path, args.scan_path, args.output_dir, args.save_sample)
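`fill_hull` is the workhorse behind the renal-pelvis proxy: the kidney mask is replaced by its filled convex hull, the kidney itself is subtracted, and what remains approximates the pelvis/collecting system, where excreted contrast pools in delayed phases. A small self-contained check of the hull fill, using the same construction as above on a toy mask rather than CT data:

```python
import numpy as np
import scipy.spatial


def fill_hull(image):
    # Delaunay point-in-hull test over the convex hull vertices
    points = np.transpose(np.where(image))
    hull = scipy.spatial.ConvexHull(points)
    deln = scipy.spatial.Delaunay(points[hull.vertices])
    idx = np.stack(np.indices(image.shape), axis=-1)
    out_idx = np.nonzero(deln.find_simplex(idx) + 1)
    out_img = np.zeros(image.shape)
    out_img[out_idx] = 1
    return out_img


# hollow box: the hull fill should recover the solid interior
mask = np.zeros((20, 20, 20))
mask[5:15, 5:15, 5] = 1
mask[5:15, 5:15, 14] = 1
mask[5, 5:15, 5:15] = 1
mask[14, 5:15, 5:15] = 1
filled = fill_hull(mask)
print(mask.sum(), "->", filled.sum())  # interior voxels are now included
```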
Comp2Comp-main/comp2comp/contrast_phase/contrast_phase.py ADDED
@@ -0,0 +1,116 @@
+ import os
+ from pathlib import Path
+ from time import time
+ from typing import Union
+
+ from totalsegmentator.libs import (
+     download_pretrained_weights,
+     nostdout,
+     setup_nnunet,
+ )
+
+ from comp2comp.contrast_phase.contrast_inf import predict_phase
+ from comp2comp.inference_class_base import InferenceClass
+
+
+ class ContrastPhaseDetection(InferenceClass):
+     """Contrast phase detection."""
+
+     def __init__(self, input_path):
+         super().__init__()
+         self.input_path = input_path
+
+     def __call__(self, inference_pipeline):
+         self.output_dir = inference_pipeline.output_dir
+         self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
+         if not os.path.exists(self.output_dir_segmentations):
+             os.makedirs(self.output_dir_segmentations)
+         self.model_dir = inference_pipeline.model_dir
+
+         seg, img = self.run_segmentation(
+             os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
+             self.output_dir_segmentations + "s01.nii.gz",
+             inference_pipeline.model_dir,
+         )
+
+         # segArray, imgArray = self.convertNibToNumpy(seg, img)
+
+         imgNiftiPath = os.path.join(
+             self.output_dir_segmentations, "converted_dcm.nii.gz"
+         )
+         segNiftPath = os.path.join(self.output_dir_segmentations, "s01.nii.gz")
+
+         predict_phase(segNiftPath, imgNiftiPath, outputPath=self.output_dir)
+
+         return {}
+
+     def run_segmentation(
+         self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir
+     ):
+         """Run segmentation.
+
+         Args:
+             input_path (Union[str, Path]): Input path.
+             output_path (Union[str, Path]): Output path.
+         """
+
+         print("Segmenting...")
+         st = time()
+         os.environ["SCRATCH"] = self.model_dir
+
+         # Setup nnunet
+         model = "3d_fullres"
+         folds = [0]
+         trainer = "nnUNetTrainerV2_ep4000_nomirror"
+         crop_path = None
+         task_id = [251]
+
+         setup_nnunet()
+         # avoid rebinding task_id, which is passed as a list below
+         for tid in task_id:
+             download_pretrained_weights(tid)
+
+         from totalsegmentator.nnunet import nnUNet_predict_image
+
+         with nostdout():
+             img, seg = nnUNet_predict_image(
+                 input_path,
+                 output_path,
+                 task_id,
+                 model=model,
+                 folds=folds,
+                 trainer=trainer,
+                 tta=False,
+                 multilabel_image=True,
+                 resample=1.5,
+                 crop=None,
+                 crop_path=crop_path,
+                 task_name="total",
+                 nora_tag=None,
+                 preview=False,
+                 nr_threads_resampling=1,
+                 nr_threads_saving=6,
+                 quiet=False,
+                 verbose=False,
+                 test=0,
+             )
+         end = time()
+
+         # Log total time for segmentation
+         print(f"Total time for segmentation: {end-st:.2f}s.")
+
+         return seg, img
+
+     def convertNibToNumpy(self, TSNib, ImageNib):
+         """Convert NIfTI images to numpy arrays.
+
+         Args:
+             TSNib (nibabel.nifti1.Nifti1Image): TotalSegmentator output.
+             ImageNib (nibabel.nifti1.Nifti1Image): Input image.
+
+         Returns:
+             numpy.ndarray: TotalSegmentator output.
+             numpy.ndarray: Input image.
+         """
+         TS_array = TSNib.get_fdata()
+         img_array = ImageNib.get_fdata()
+         return TS_array, img_array
Comp2Comp-main/comp2comp/contrast_phase/xgboost.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:070af05754cc9541e924c0ede654b1c40a01b9240f14483af5284ae0b92d4169
+ size 422989
Comp2Comp-main/comp2comp/hip/hip.py ADDED
@@ -0,0 +1,301 @@
+ """
2
+ @author: louisblankemeier
3
+ """
4
+
5
+ import os
6
+ from pathlib import Path
7
+ from time import time
8
+ from typing import Union
9
+
10
+ import pandas as pd
11
+ from totalsegmentator.libs import (
12
+ download_pretrained_weights,
13
+ nostdout,
14
+ setup_nnunet,
15
+ )
16
+
17
+ from comp2comp.hip import hip_utils
18
+ from comp2comp.hip.hip_visualization import (
19
+ hip_report_visualizer,
20
+ hip_roi_visualizer,
21
+ )
22
+ from comp2comp.inference_class_base import InferenceClass
23
+ from comp2comp.models.models import Models
24
+
25
+
26
+ class HipSegmentation(InferenceClass):
27
+ """Spine segmentation."""
28
+
29
+ def __init__(self, model_name):
30
+ super().__init__()
31
+ self.model_name = model_name
32
+ self.model = Models.model_from_name(model_name)
33
+
34
+ def __call__(self, inference_pipeline):
35
+ # inference_pipeline.dicom_series_path = self.input_path
36
+ self.output_dir = inference_pipeline.output_dir
37
+ self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
38
+ if not os.path.exists(self.output_dir_segmentations):
39
+ os.makedirs(self.output_dir_segmentations)
40
+
41
+ self.model_dir = inference_pipeline.model_dir
42
+
43
+ seg, mv = self.hip_seg(
44
+ os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
45
+ self.output_dir_segmentations + "hip.nii.gz",
46
+ inference_pipeline.model_dir,
47
+ )
48
+
49
+ inference_pipeline.model = self.model
50
+ inference_pipeline.segmentation = seg
51
+ inference_pipeline.medical_volume = mv
52
+
53
+ return {}
54
+
55
+ def hip_seg(
56
+ self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir
57
+ ):
58
+ """Run spine segmentation.
59
+
60
+ Args:
61
+ input_path (Union[str, Path]): Input path.
62
+ output_path (Union[str, Path]): Output path.
63
+ """
64
+
65
+ print("Segmenting hip...")
66
+ st = time()
67
+ os.environ["SCRATCH"] = self.model_dir
68
+
69
+ # Setup nnunet
70
+ model = "3d_fullres"
71
+ folds = [0]
72
+ trainer = "nnUNetTrainerV2_ep4000_nomirror"
73
+ crop_path = None
74
+ task_id = [254]
75
+
76
+ if self.model_name == "ts_hip":
77
+ setup_nnunet()
78
+ download_pretrained_weights(task_id[0])
79
+ else:
80
+ raise ValueError("Invalid model name.")
81
+
82
+ from totalsegmentator.nnunet import nnUNet_predict_image
83
+
84
+ with nostdout():
85
+ img, seg = nnUNet_predict_image(
86
+ input_path,
87
+ output_path,
88
+ task_id,
89
+ model=model,
90
+ folds=folds,
91
+ trainer=trainer,
92
+ tta=False,
93
+ multilabel_image=True,
94
+ resample=1.5,
95
+ crop=None,
96
+ crop_path=crop_path,
97
+ task_name="total",
98
+ nora_tag=None,
99
+ preview=False,
100
+ nr_threads_resampling=1,
101
+ nr_threads_saving=6,
102
+ quiet=False,
103
+ verbose=False,
104
+ test=0,
105
+ )
106
+ end = time()
107
+
108
+ # Log total time for hip segmentation
109
+ print(f"Total time for hip segmentation: {end-st:.2f}s.")
110
+
111
+ return seg, img
112
+
113
+
114
+ class HipComputeROIs(InferenceClass):
115
+ def __init__(self, hip_model):
116
+ super().__init__()
117
+ self.hip_model_name = hip_model
118
+ self.hip_model_type = Models.model_from_name(self.hip_model_name)
119
+
120
+ def __call__(self, inference_pipeline):
121
+ segmentation = inference_pipeline.segmentation
122
+ medical_volume = inference_pipeline.medical_volume
123
+
124
+ model = inference_pipeline.model
125
+ images_folder = os.path.join(inference_pipeline.output_dir, "dev")
126
+ results_dict = hip_utils.compute_rois(
127
+ medical_volume, segmentation, model, images_folder
128
+ )
129
+ inference_pipeline.femur_results_dict = results_dict
130
+ return {}
131
+
132
+
133
+ class HipMetricsSaver(InferenceClass):
134
+ """Save metrics to a CSV file."""
135
+
136
+ def __init__(self):
137
+ super().__init__()
138
+
139
+ def __call__(self, inference_pipeline):
140
+ metrics_output_dir = os.path.join(inference_pipeline.output_dir, "metrics")
141
+ if not os.path.exists(metrics_output_dir):
142
+ os.makedirs(metrics_output_dir)
143
+ results_dict = inference_pipeline.femur_results_dict
144
+ left_head_hu = results_dict["left_head"]["hu"]
145
+ right_head_hu = results_dict["right_head"]["hu"]
146
+ left_intertrochanter_hu = results_dict["left_intertrochanter"]["hu"]
147
+ right_intertrochanter_hu = results_dict["right_intertrochanter"]["hu"]
148
+ left_neck_hu = results_dict["left_neck"]["hu"]
149
+ right_neck_hu = results_dict["right_neck"]["hu"]
150
+ # save to csv
151
+ df = pd.DataFrame(
152
+ {
153
+ "Left Head (HU)": [left_head_hu],
154
+ "Right Head (HU)": [right_head_hu],
155
+ "Left Intertrochanter (HU)": [left_intertrochanter_hu],
156
+ "Right Intertrochanter (HU)": [right_intertrochanter_hu],
157
+ "Left Neck (HU)": [left_neck_hu],
158
+ "Right Neck (HU)": [right_neck_hu],
159
+ }
160
+ )
161
+ df.to_csv(os.path.join(metrics_output_dir, "hip_metrics.csv"), index=False)
162
+ return {}
163
+
164
+
165
+ class HipVisualizer(InferenceClass):
166
+ def __init__(self):
167
+ super().__init__()
168
+
169
+ def __call__(self, inference_pipeline):
170
+ medical_volume = inference_pipeline.medical_volume
171
+
172
+ left_head_roi = inference_pipeline.femur_results_dict["left_head"]["roi"]
173
+ left_head_centroid = inference_pipeline.femur_results_dict["left_head"][
174
+ "centroid"
175
+ ]
176
+ left_head_hu = inference_pipeline.femur_results_dict["left_head"]["hu"]
177
+
178
+ left_intertrochanter_roi = inference_pipeline.femur_results_dict[
179
+ "left_intertrochanter"
180
+ ]["roi"]
181
+ left_intertrochanter_centroid = inference_pipeline.femur_results_dict[
182
+ "left_intertrochanter"
183
+ ]["centroid"]
184
+ left_intertrochanter_hu = inference_pipeline.femur_results_dict[
185
+ "left_intertrochanter"
186
+ ]["hu"]
187
+
188
+ left_neck_roi = inference_pipeline.femur_results_dict["left_neck"]["roi"]
189
+ left_neck_centroid = inference_pipeline.femur_results_dict["left_neck"][
190
+ "centroid"
191
+ ]
192
+ left_neck_hu = inference_pipeline.femur_results_dict["left_neck"]["hu"]
193
+
194
+ right_head_roi = inference_pipeline.femur_results_dict["right_head"]["roi"]
195
+ right_head_centroid = inference_pipeline.femur_results_dict["right_head"][
196
+ "centroid"
197
+ ]
198
+ right_head_hu = inference_pipeline.femur_results_dict["right_head"]["hu"]
199
+
200
+ right_intertrochanter_roi = inference_pipeline.femur_results_dict[
201
+ "right_intertrochanter"
202
+ ]["roi"]
203
+ right_intertrochanter_centroid = inference_pipeline.femur_results_dict[
204
+ "right_intertrochanter"
205
+ ]["centroid"]
206
+ right_intertrochanter_hu = inference_pipeline.femur_results_dict[
207
+ "right_intertrochanter"
208
+ ]["hu"]
209
+
210
+ right_neck_roi = inference_pipeline.femur_results_dict["right_neck"]["roi"]
211
+ right_neck_centroid = inference_pipeline.femur_results_dict["right_neck"][
212
+ "centroid"
213
+ ]
214
+ right_neck_hu = inference_pipeline.femur_results_dict["right_neck"]["hu"]
215
+
216
+ output_dir = inference_pipeline.output_dir
217
+ images_output_dir = os.path.join(output_dir, "images")
218
+ if not os.path.exists(images_output_dir):
219
+ os.makedirs(images_output_dir)
220
+ hip_roi_visualizer(
221
+ medical_volume,
222
+ left_head_roi,
223
+ left_head_centroid,
224
+ left_head_hu,
225
+ images_output_dir,
226
+ "left_head",
227
+ )
228
+ hip_roi_visualizer(
229
+ medical_volume,
230
+ left_intertrochanter_roi,
231
+ left_intertrochanter_centroid,
232
+ left_intertrochanter_hu,
233
+ images_output_dir,
234
+ "left_intertrochanter",
235
+ )
236
+ hip_roi_visualizer(
237
+ medical_volume,
238
+ left_neck_roi,
239
+ left_neck_centroid,
240
+ left_neck_hu,
241
+ images_output_dir,
242
+ "left_neck",
243
+ )
244
+ hip_roi_visualizer(
245
+ medical_volume,
246
+ right_head_roi,
247
+ right_head_centroid,
248
+ right_head_hu,
249
+ images_output_dir,
250
+ "right_head",
251
+ )
252
+ hip_roi_visualizer(
253
+ medical_volume,
254
+ right_intertrochanter_roi,
255
+ right_intertrochanter_centroid,
256
+ right_intertrochanter_hu,
257
+ images_output_dir,
258
+ "right_intertrochanter",
259
+ )
260
+ hip_roi_visualizer(
261
+ medical_volume,
262
+ right_neck_roi,
263
+ right_neck_centroid,
264
+ right_neck_hu,
265
+ images_output_dir,
266
+ "right_neck",
267
+ )
268
+ hip_report_visualizer(
269
+ medical_volume.get_fdata(),
270
+ left_head_roi + right_head_roi,
271
+ [left_head_centroid, right_head_centroid],
272
+ images_output_dir,
273
+ "head",
274
+ {
275
+ "Left Head HU": round(left_head_hu),
276
+ "Right Head HU": round(right_head_hu),
277
+ },
278
+ )
279
+ hip_report_visualizer(
280
+ medical_volume.get_fdata(),
281
+ left_intertrochanter_roi + right_intertrochanter_roi,
282
+ [left_intertrochanter_centroid, right_intertrochanter_centroid],
283
+ images_output_dir,
284
+ "intertrochanter",
285
+ {
286
+ "Left Intertrochanter HU": round(left_intertrochanter_hu),
287
+ "Right Intertrochanter HU": round(right_intertrochanter_hu),
288
+ },
289
+ )
290
+ hip_report_visualizer(
291
+ medical_volume.get_fdata(),
292
+ left_neck_roi + right_neck_roi,
293
+ [left_neck_centroid, right_neck_centroid],
294
+ images_output_dir,
295
+ "neck",
296
+ {
297
+ "Left Neck HU": round(left_neck_hu),
298
+ "Right Neck HU": round(right_neck_hu),
299
+ },
300
+ )
301
+ return {}
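All of the classes above communicate through `inference_pipeline.femur_results_dict`, whose shape is fixed by `hip_utils.compute_rois` (shown in the next file). A sketch of that contract with mock values (the arrays and numbers are placeholders, not real outputs):

```python
import numpy as np

# six regions, each mapping to an ROI mask, a centroid, and a mean HU value
regions = [
    "left_head", "right_head",
    "left_intertrochanter", "right_intertrochanter",
    "left_neck", "right_neck",
]
femur_results_dict = {
    r: {"roi": np.zeros((4, 4, 4), dtype=np.uint8), "centroid": [0, 0, 0], "hu": 150.0}
    for r in regions
}

for region, entry in femur_results_dict.items():
    print(f"{region:>22}: HU={entry['hu']:.1f}, centroid={entry['centroid']}")
```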
Comp2Comp-main/comp2comp/hip/hip_utils.py ADDED
@@ -0,0 +1,362 @@
+ """
2
+ @author: louisblankemeier
3
+ """
4
+
5
+ import math
6
+ import os
7
+ import shutil
8
+
9
+ import cv2
10
+ import nibabel as nib
11
+ import numpy as np
12
+ import scipy.ndimage as ndi
13
+ from scipy.ndimage import zoom
14
+ from skimage.morphology import ball, binary_erosion
15
+
16
+ from comp2comp.hip.hip_visualization import method_visualizer
17
+
18
+
19
+ def compute_rois(medical_volume, segmentation, model, output_dir, save=False):
20
+ left_femur_mask = segmentation.get_fdata() == model.categories["femur_left"]
21
+ left_femur_mask = left_femur_mask.astype(np.uint8)
22
+ right_femur_mask = segmentation.get_fdata() == model.categories["femur_right"]
23
+ right_femur_mask = right_femur_mask.astype(np.uint8)
24
+ left_head_roi, left_head_centroid, left_head_hu = get_femural_head_roi(
25
+ left_femur_mask, medical_volume, output_dir, "left_head"
26
+ )
27
+ right_head_roi, right_head_centroid, right_head_hu = get_femural_head_roi(
28
+ right_femur_mask, medical_volume, output_dir, "right_head"
29
+ )
30
+ (
31
+ left_intertrochanter_roi,
32
+ left_intertrochanter_centroid,
33
+ left_intertrochanter_hu,
34
+ ) = get_femural_head_roi(
35
+ left_femur_mask, medical_volume, output_dir, "left_intertrochanter"
36
+ )
37
+ (
38
+ right_intertrochanter_roi,
39
+ right_intertrochanter_centroid,
40
+ right_intertrochanter_hu,
41
+ ) = get_femural_head_roi(
42
+ right_femur_mask, medical_volume, output_dir, "right_intertrochanter"
43
+ )
44
+ (
45
+ left_neck_roi,
46
+ left_neck_centroid,
47
+ left_neck_hu,
48
+ ) = get_femural_neck_roi(
49
+ left_femur_mask,
50
+ medical_volume,
51
+ left_intertrochanter_roi,
52
+ left_intertrochanter_centroid,
53
+ left_head_roi,
54
+ left_head_centroid,
55
+ output_dir,
56
+ )
57
+ (
58
+ right_neck_roi,
59
+ right_neck_centroid,
60
+ right_neck_hu,
61
+ ) = get_femural_neck_roi(
62
+ right_femur_mask,
63
+ medical_volume,
64
+ right_intertrochanter_roi,
65
+ right_intertrochanter_centroid,
66
+ right_head_roi,
67
+ right_head_centroid,
68
+ output_dir,
69
+ )
70
+ combined_roi = (
71
+ left_head_roi
72
+ + (right_head_roi) # * 2)
73
+ + (left_intertrochanter_roi) # * 3)
74
+ + (right_intertrochanter_roi) # * 4)
75
+ + (left_neck_roi) # * 5)
76
+ + (right_neck_roi) # * 6)
77
+ )
78
+
79
+ if save:
80
+ # make roi directory if it doesn't exist
81
+ parent_output_dir = os.path.dirname(output_dir)
82
+ roi_output_dir = os.path.join(parent_output_dir, "rois")
83
+ if not os.path.exists(roi_output_dir):
84
+ os.makedirs(roi_output_dir)
85
+
86
+ # Convert left ROI to NIfTI
87
+ left_roi_nifti = nib.Nifti1Image(combined_roi, medical_volume.affine)
88
+ left_roi_path = os.path.join(roi_output_dir, "roi.nii.gz")
89
+ nib.save(left_roi_nifti, left_roi_path)
90
+ shutil.copy(
91
+ os.path.join(
92
+ os.path.dirname(os.path.abspath(__file__)),
93
+ "tunnelvision.ipynb",
94
+ ),
95
+ parent_output_dir,
96
+ )
97
+
98
+ return {
99
+ "left_head": {
100
+ "roi": left_head_roi,
101
+ "centroid": left_head_centroid,
102
+ "hu": left_head_hu,
103
+ },
104
+ "right_head": {
105
+ "roi": right_head_roi,
106
+ "centroid": right_head_centroid,
107
+ "hu": right_head_hu,
108
+ },
109
+ "left_intertrochanter": {
110
+ "roi": left_intertrochanter_roi,
111
+ "centroid": left_intertrochanter_centroid,
112
+ "hu": left_intertrochanter_hu,
113
+ },
114
+ "right_intertrochanter": {
115
+ "roi": right_intertrochanter_roi,
116
+ "centroid": right_intertrochanter_centroid,
117
+ "hu": right_intertrochanter_hu,
118
+ },
119
+ "left_neck": {
120
+ "roi": left_neck_roi,
121
+ "centroid": left_neck_centroid,
122
+ "hu": left_neck_hu,
123
+ },
124
+ "right_neck": {
125
+ "roi": right_neck_roi,
126
+ "centroid": right_neck_centroid,
127
+ "hu": right_neck_hu,
128
+ },
129
+ }
130
+
131
+
132
+ def get_femural_head_roi(
133
+ femur_mask,
134
+ medical_volume,
135
+ output_dir,
136
+ anatomy,
137
+ visualize_method=False,
138
+ min_pixel_count=20,
139
+ ):
140
+ top = np.where(femur_mask.sum(axis=(0, 1)) != 0)[0].max()
141
+ top_mask = femur_mask[:, :, top]
142
+
143
+ print(f"======== Computing {anatomy} femur ROIs ========")
144
+
145
+ while True:
146
+ labeled, num_features = ndi.label(top_mask)
147
+
148
+ component_sizes = np.bincount(labeled.ravel())
149
+ valid_components = np.where(component_sizes >= min_pixel_count)[0][1:]
150
+
151
+ if len(valid_components) == 2:
152
+ break
153
+
154
+ top -= 1
155
+ if top < 0:
156
+ print("Two connected components not found in the femur mask.")
157
+ break
158
+ top_mask = femur_mask[:, :, top]
159
+
160
+ if len(valid_components) == 2:
161
+ # Find the center of mass for each connected component
162
+ center_of_mass_1 = list(
163
+ ndi.center_of_mass(top_mask, labeled, valid_components[0])
164
+ )
165
+ center_of_mass_2 = list(
166
+ ndi.center_of_mass(top_mask, labeled, valid_components[1])
167
+ )
168
+
169
+ # Assign left_center_of_mass to be the center of mass with lowest value in the first dimension
170
+ if center_of_mass_1[0] < center_of_mass_2[0]:
171
+ left_center_of_mass = center_of_mass_1
172
+ right_center_of_mass = center_of_mass_2
173
+ else:
174
+ left_center_of_mass = center_of_mass_2
175
+ right_center_of_mass = center_of_mass_1
176
+
177
+ print(f"Left center of mass: {left_center_of_mass}")
178
+ print(f"Right center of mass: {right_center_of_mass}")
179
+
180
+ if anatomy == "left_intertrochanter" or anatomy == "right_head":
181
+ center_of_mass = left_center_of_mass
182
+ elif anatomy == "right_intertrochanter" or anatomy == "left_head":
183
+ center_of_mass = right_center_of_mass
184
+
185
+    coronal_slice = femur_mask[:, round(center_of_mass[1]), :]
+    coronal_image = medical_volume.get_fdata()[:, round(center_of_mass[1]), :]
+    sagittal_slice = femur_mask[round(center_of_mass[0]), :, :]
+    sagittal_image = medical_volume.get_fdata()[round(center_of_mass[0]), :, :]
+
+    zooms = medical_volume.header.get_zooms()
+    zoom_factor = zooms[2] / zooms[1]
+
+    coronal_slice = zoom(coronal_slice, (1, zoom_factor), order=1).round()
+    coronal_image = zoom(coronal_image, (1, zoom_factor), order=3).round()
+    sagittal_image = zoom(sagittal_image, (1, zoom_factor), order=3).round()
+
+    centroid = [round(center_of_mass[0]), 0, 0]
+
+    print(f"Starting centroid: {centroid}")
+
+    for _ in range(3):
+        sagittal_slice = femur_mask[centroid[0], :, :]
+        sagittal_slice = zoom(sagittal_slice, (1, zoom_factor), order=1).round()
+        centroid[1], centroid[2], radius_sagittal = inscribe_sagittal(
+            sagittal_slice, zoom_factor
+        )
+
+        print(f"Centroid after inscribe sagittal: {centroid}")
+
+        axial_slice = femur_mask[:, :, centroid[2]]
+        if anatomy == "left_intertrochanter" or anatomy == "right_head":
+            axial_slice[round(right_center_of_mass[0]) :, :] = 0
+        elif anatomy == "right_intertrochanter" or anatomy == "left_head":
+            axial_slice[: round(left_center_of_mass[0]), :] = 0
+        centroid[0], centroid[1], radius_axial = inscribe_axial(axial_slice)
+
+        print(f"Centroid after inscribe axial: {centroid}")
+
+    axial_image = medical_volume.get_fdata()[:, :, round(centroid[2])]
+    sagittal_image = medical_volume.get_fdata()[round(centroid[0]), :, :]
+    sagittal_image = zoom(sagittal_image, (1, zoom_factor), order=3).round()
+
+    if visualize_method:
+        method_visualizer(
+            sagittal_image,
+            axial_image,
+            axial_slice,
+            sagittal_slice,
+            [centroid[2], centroid[1]],
+            radius_sagittal,
+            [centroid[1], centroid[0]],
+            radius_axial,
+            output_dir,
+            anatomy,
+        )
+
+    roi = compute_hip_roi(medical_volume, centroid, radius_sagittal, radius_axial)
+
+    # selem = ndi.generate_binary_structure(3, 1)
+    selem = ball(3)
+    femur_mask_eroded = binary_erosion(femur_mask, selem)
+    roi = roi * femur_mask_eroded
+    roi_eroded = roi.astype(np.uint8)
+
+    hu = get_mean_roi_hu(medical_volume, roi_eroded)
+
+    return (roi_eroded, centroid, hu)
+
+
+def get_femural_neck_roi(
+    femur_mask,
+    medical_volume,
+    intertrochanter_roi,
+    intertrochanter_centroid,
+    head_roi,
+    head_centroid,
+    output_dir,
+):
+    zooms = medical_volume.header.get_zooms()
+
+    direction_vector = np.array(head_centroid) - np.array(intertrochanter_centroid)
+    unit_direction_vector = direction_vector / np.linalg.norm(direction_vector)
+
+    z, y, x = np.where(intertrochanter_roi)
+    intertrochanter_points = np.column_stack((z, y, x))
+    t_start = np.dot(
+        intertrochanter_points - intertrochanter_centroid, unit_direction_vector
+    ).max()
+
+    z, y, x = np.where(head_roi)
+    head_points = np.column_stack((z, y, x))
+    t_end = (
+        np.linalg.norm(direction_vector)
+        + np.dot(head_points - head_centroid, unit_direction_vector).min()
+    )
+
+    z, y, x = np.indices(femur_mask.shape)
+    coordinates = np.stack((z, y, x), axis=-1)
+
+    distance_to_line_origin = np.dot(
+        coordinates - intertrochanter_centroid, unit_direction_vector
+    )
+
+    coordinates_zoomed = coordinates * zooms
+    intertrochanter_centroid_zoomed = np.array(intertrochanter_centroid) * zooms
+    unit_direction_vector_zoomed = unit_direction_vector * zooms
+
+    distance_to_line = np.linalg.norm(
+        np.cross(
+            coordinates_zoomed - intertrochanter_centroid_zoomed,
+            coordinates_zoomed
+            - (intertrochanter_centroid_zoomed + unit_direction_vector_zoomed),
+        ),
+        axis=-1,
+    ) / np.linalg.norm(unit_direction_vector_zoomed)
+
+    cylinder_radius = 10
+
+    cylinder_mask = (
+        (distance_to_line <= cylinder_radius)
+        & (distance_to_line_origin >= t_start)
+        & (distance_to_line_origin <= t_end)
+    )
+
+    # selem = ndi.generate_binary_structure(3, 1)
+    selem = ball(3)
+    femur_mask_eroded = binary_erosion(femur_mask, selem)
+    roi = cylinder_mask * femur_mask_eroded
+    neck_roi = roi.astype(np.uint8)
+
+    hu = get_mean_roi_hu(medical_volume, neck_roi)
+
+    centroid = list(
+        intertrochanter_centroid + unit_direction_vector * (t_start + t_end) / 2
+    )
+    centroid = [round(x) for x in centroid]
+
+    return neck_roi, centroid, hu
+
+
+def compute_hip_roi(img, centroid, radius_sagittal, radius_axial):
+    pixel_spacing = img.header.get_zooms()
+    length_i = radius_axial * 0.75 / pixel_spacing[0]
+    length_j = radius_axial * 0.75 / pixel_spacing[1]
+    length_k = radius_sagittal * 0.75 / pixel_spacing[2]
+
+    roi = np.zeros(img.get_fdata().shape, dtype=np.uint8)
+    i_lower = math.floor(centroid[0] - length_i)
+    j_lower = math.floor(centroid[1] - length_j)
+    k_lower = math.floor(centroid[2] - length_k)
+    for i in range(i_lower, i_lower + 2 * math.ceil(length_i) + 1):
+        for j in range(j_lower, j_lower + 2 * math.ceil(length_j) + 1):
+            for k in range(k_lower, k_lower + 2 * math.ceil(length_k) + 1):
+                if (i - centroid[0]) ** 2 / length_i**2 + (
+                    j - centroid[1]
+                ) ** 2 / length_j**2 + (k - centroid[2]) ** 2 / length_k**2 <= 1:
+                    roi[i, j, k] = 1
+    return roi
+
+
+def inscribe_axial(axial_mask):
+    dist_map = cv2.distanceTransform(axial_mask, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
+    _, radius_axial, _, center_axial = cv2.minMaxLoc(dist_map)
+    center_axial = list(center_axial)
+    left_right_center = round(center_axial[1])
+    posterior_anterior_center = round(center_axial[0])
+    return left_right_center, posterior_anterior_center, radius_axial
+
+
+def inscribe_sagittal(sagittal_mask, zoom_factor):
+    dist_map = cv2.distanceTransform(sagittal_mask, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
+    _, radius_sagittal, _, center_sagittal = cv2.minMaxLoc(dist_map)
+    center_sagittal = list(center_sagittal)
+    posterior_anterior_center = round(center_sagittal[1])
+    inferior_superior_center = round(center_sagittal[0])
+    inferior_superior_center = round(inferior_superior_center / zoom_factor)
+    return posterior_anterior_center, inferior_superior_center, radius_sagittal
+
+
+def get_mean_roi_hu(medical_volume, roi):
+    masked_medical_volume = medical_volume.get_fdata() * roi
+    return np.mean(masked_medical_volume[masked_medical_volume != 0])
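
Aside for readers of this diff: `inscribe_axial` and `inscribe_sagittal` locate the largest circle that fits inside a 2D femur mask by taking the peak of the Euclidean distance transform (its location is the circle center, its value the radius). A minimal self-contained sketch of that idea on a toy mask, not repo code:

import cv2
import numpy as np

mask = np.zeros((64, 64), dtype=np.uint8)
cv2.circle(mask, (40, 30), 15, 1, -1)  # toy circular "bone" cross-section

dist_map = cv2.distanceTransform(mask, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
_, radius, _, center = cv2.minMaxLoc(dist_map)  # max value and its location
print(center, radius)  # approximately (40, 30) and 15
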
Comp2Comp-main/comp2comp/hip/hip_visualization.py ADDED
@@ -0,0 +1,171 @@
+"""
+@author: louisblankemeier
+"""
+
+import os
+
+import numpy as np
+from scipy.ndimage import zoom
+
+from comp2comp.visualization.detectron_visualizer import Visualizer
+from comp2comp.visualization.linear_planar_reformation import (
+    linear_planar_reformation,
+)
+
+
+def method_visualizer(
+    sagittal_image,
+    axial_image,
+    axial_slice,
+    sagittal_slice,
+    center_sagittal,
+    radius_sagittal,
+    center_axial,
+    radius_axial,
+    output_dir,
+    anatomy,
+):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    axial_image = np.clip(axial_image, -300, 1800)
+    axial_image = normalize_img(axial_image) * 255.0
+
+    sagittal_image = np.clip(sagittal_image, -300, 1800)
+    sagittal_image = normalize_img(sagittal_image) * 255.0
+
+    sagittal_image = sagittal_image.reshape(
+        (sagittal_image.shape[0], sagittal_image.shape[1], 1)
+    )
+    img_rgb = np.tile(sagittal_image, (1, 1, 3))
+    vis = Visualizer(img_rgb)
+    vis.draw_circle(
+        circle_coord=center_sagittal, color=[0, 1, 0], radius=radius_sagittal
+    )
+    vis.draw_binary_mask(sagittal_slice)
+
+    vis_obj = vis.get_output()
+    vis_obj.save(os.path.join(output_dir, f"{anatomy}_sagittal_method.png"))
+
+    axial_image = axial_image.reshape((axial_image.shape[0], axial_image.shape[1], 1))
+    img_rgb = np.tile(axial_image, (1, 1, 3))
+    vis = Visualizer(img_rgb)
+    vis.draw_circle(circle_coord=center_axial, color=[0, 1, 0], radius=radius_axial)
+    vis.draw_binary_mask(axial_slice)
+
+    vis_obj = vis.get_output()
+    vis_obj.save(os.path.join(output_dir, f"{anatomy}_axial_method.png"))
+
+
+def hip_roi_visualizer(
+    medical_volume,
+    roi,
+    centroid,
+    hu,
+    output_dir,
+    anatomy,
+):
+    zooms = medical_volume.header.get_zooms()
+    zoom_factor = zooms[2] / zooms[1]
+
+    sagittal_image = medical_volume.get_fdata()[centroid[0], :, :]
+    sagittal_roi = roi[centroid[0], :, :]
+
+    sagittal_image = zoom(sagittal_image, (1, zoom_factor), order=1).round()
+    sagittal_roi = zoom(sagittal_roi, (1, zoom_factor), order=3).round()
+    sagittal_image = np.flip(sagittal_image.T)
+    sagittal_roi = np.flip(sagittal_roi.T)
+
+    axial_image = medical_volume.get_fdata()[:, :, round(centroid[2])]
+    axial_roi = roi[:, :, round(centroid[2])]
+
+    axial_image = np.flip(axial_image.T)
+    axial_roi = np.flip(axial_roi.T)
+
+    _ROI_COLOR = np.array([1.000, 0.340, 0.200])
+
+    sagittal_image = np.clip(sagittal_image, -300, 1800)
+    sagittal_image = normalize_img(sagittal_image) * 255.0
+    sagittal_image = sagittal_image.reshape(
+        (sagittal_image.shape[0], sagittal_image.shape[1], 1)
+    )
+    img_rgb = np.tile(sagittal_image, (1, 1, 3))
+    vis = Visualizer(img_rgb)
+    vis.draw_binary_mask(
+        sagittal_roi,
+        color=_ROI_COLOR,
+        edge_color=_ROI_COLOR,
+        alpha=0.0,
+        area_threshold=0,
+    )
+    vis.draw_text(
+        text=f"Mean HU: {round(hu)}",
+        position=(412, 10),
+        color=_ROI_COLOR,
+        font_size=9,
+        horizontal_alignment="left",
+    )
+    vis_obj = vis.get_output()
+    vis_obj.save(os.path.join(output_dir, f"{anatomy}_hip_roi_sagittal.png"))
+
+    """
+    axial_image = np.clip(axial_image, -300, 1800)
+    axial_image = normalize_img(axial_image) * 255.0
+    axial_image = axial_image.reshape((axial_image.shape[0], axial_image.shape[1], 1))
+    img_rgb = np.tile(axial_image, (1, 1, 3))
+    vis = Visualizer(img_rgb)
+    vis.draw_binary_mask(
+        axial_roi, color=_ROI_COLOR, edge_color=_ROI_COLOR, alpha=0.0, area_threshold=0
+    )
+    vis.draw_text(
+        text=f"Mean HU: {round(hu)}",
+        position=(412, 10),
+        color=_ROI_COLOR,
+        font_size=9,
+        horizontal_alignment="left",
+    )
+    vis_obj = vis.get_output()
+    vis_obj.save(os.path.join(output_dir, f"{anatomy}_hip_roi_axial.png"))
+    """
+
+
+def hip_report_visualizer(medical_volume, roi, centroids, output_dir, anatomy, labels):
+    _ROI_COLOR = np.array([1.000, 0.340, 0.200])
+    image, mask = linear_planar_reformation(
+        medical_volume, roi, centroids, dimension="axial"
+    )
+    # add 3rd dim to image
+    image = np.flip(image.T)
+    mask = np.flip(mask.T)
+    mask[mask > 1] = 1
+    # mask = np.expand_dims(mask, axis=2)
+    image = np.expand_dims(image, axis=2)
+    image = np.clip(image, -300, 1800)
+    image = normalize_img(image) * 255.0
+    img_rgb = np.tile(image, (1, 1, 3))
+    vis = Visualizer(img_rgb)
+    vis.draw_binary_mask(
+        mask, color=_ROI_COLOR, edge_color=_ROI_COLOR, alpha=0.0, area_threshold=0
+    )
+    pos_idx = 0
+    for key, value in labels.items():
+        vis.draw_text(
+            text=f"{key}: {value}",
+            position=(310, 10 + pos_idx * 17),
+            color=_ROI_COLOR,
+            font_size=9,
+            horizontal_alignment="left",
+        )
+        pos_idx += 1
+    vis_obj = vis.get_output()
+    vis_obj.save(os.path.join(output_dir, f"{anatomy}_report_axial.png"))
+
+
+def normalize_img(img: np.ndarray) -> np.ndarray:
+    """Normalize the image.
+    Args:
+        img (np.ndarray): Input image.
+    Returns:
+        np.ndarray: Normalized image.
+    """
+    return (img - img.min()) / (img.max() - img.min())
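
Side note on the intensity handling above (synthetic numbers, not repo code): slices are clipped to the [-300, 1800] HU display range and then min-max scaled to 8-bit gray levels before drawing:

import numpy as np

img = np.random.uniform(-1000, 2000, size=(16, 16))  # fake HU slice
img = np.clip(img, -300, 1800)
img_8bit = (img - img.min()) / (img.max() - img.min()) * 255.0
print(img_8bit.min(), img_8bit.max())  # 0.0 255.0
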
Comp2Comp-main/comp2comp/hip/tunnelvision.ipynb ADDED
@@ -0,0 +1,73 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import voxel as vx\n",
+    "import tunnelvision as tv\n",
+    "import numpy as np\n",
+    "\n",
+    "mv = vx.load(\"./segmentations/converted_dcm.nii.gz\")\n",
+    "mv = mv.reformat((\"LR\", \"PA\", \"IS\"))\n",
+    "np_mv = mv.A\n",
+    "np_mv = np_mv.astype(np.int32)\n",
+    "np_mv = np.expand_dims(np_mv, axis=0)\n",
+    "np_mv = np.expand_dims(np_mv, axis=4)\n",
+    "\n",
+    "seg = vx.load(\"./rois/roi.nii.gz\")\n",
+    "np_seg = seg.A\n",
+    "np_seg_dim = seg.A\n",
+    "np_seg = np_seg.astype(np.int32)\n",
+    "np_seg = np.expand_dims(np_seg, axis=0)\n",
+    "np_seg = np.expand_dims(np_seg, axis=4)\n",
+    "\n",
+    "hip_seg = vx.load(\"./segmentations/hip.nii.gz\")\n",
+    "hip_seg = hip_seg.reformat((\"LR\", \"PA\", \"IS\"))\n",
+    "np_hip_seg = hip_seg.A.astype(int)\n",
+    "# set values not equal to 88 or 89 to 0\n",
+    "np_hip_seg[(np_hip_seg != 88) & (np_hip_seg != 89)] = 0\n",
+    "np_hip_seg[np_hip_seg != 0] = np_hip_seg[np_hip_seg != 0] + 4\n",
+    "np_hip_seg[np_seg_dim != 0] = 0\n",
+    "np_hip_seg = np_hip_seg.astype(np.int32)\n",
+    "np_hip_seg = np.expand_dims(np_hip_seg, axis=0)\n",
+    "np_hip_seg = np.expand_dims(np_hip_seg, axis=4)\n",
+    "\n",
+    "ax = tv.Axes(figsize=(512, 512))\n",
+    "ax.imshow(np_mv)\n",
+    "ax.imshow(np_seg, cmap=\"seg\")\n",
+    "ax.imshow(np_hip_seg, cmap=\"seg\")\n",
+    "ax.show()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.8.16 ('c2c_env')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.16"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "62fd47c2f495fb43260e4f88a1d5487d18d4c091bac4d4df4eca96cade9f1e23"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
Comp2Comp-main/comp2comp/inference_class_base.py ADDED
@@ -0,0 +1,18 @@
+"""
+@author: louisblankemeier
+"""
+
+from typing import Dict
+
+
+class InferenceClass:
+    """Base class for inference classes."""
+
+    def __init__(self):
+        pass
+
+    def __call__(self) -> Dict:
+        raise NotImplementedError
+
+    def __repr__(self):
+        return self.__class__.__name__
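
A hypothetical subclass (not in the repo) illustrating the contract: each step receives the shared pipeline object and returns a dict whose keys become the keyword arguments of the next step:

from typing import Dict

from comp2comp.inference_class_base import InferenceClass


class MeanHU(InferenceClass):
    # assumes a prior step set inference_pipeline.medical_volume
    def __call__(self, inference_pipeline) -> Dict:
        volume = inference_pipeline.medical_volume.get_fdata()
        return {"mean_hu": float(volume.mean())}
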
Comp2Comp-main/comp2comp/inference_pipeline.py ADDED
@@ -0,0 +1,102 @@
+"""
+@author: louisblankemeier
+"""
+
+import inspect
+import os
+from typing import Dict, List
+
+from comp2comp.inference_class_base import InferenceClass
+from comp2comp.io.io import DicomLoader, NiftiSaver
+
+
+class InferencePipeline(InferenceClass):
+    """Inference pipeline."""
+
+    def __init__(self, inference_classes: List = None, config: Dict = None):
+        self.config = config
+        # assign values from config to attributes
+        if self.config is not None:
+            for key, value in self.config.items():
+                setattr(self, key, value)
+
+        self.inference_classes = inference_classes
+
+    def __call__(self, inference_pipeline=None, **kwargs):
+        # print out the class names for each inference class
+        print("")
+        print("Inference pipeline:")
+        for i, inference_class in enumerate(self.inference_classes):
+            print(f"({i + 1}) {inference_class.__repr__()}")
+        print("")
+
+        print("Starting inference pipeline.\n")
+
+        if inference_pipeline:
+            for key, value in kwargs.items():
+                setattr(inference_pipeline, key, value)
+        else:
+            for key, value in kwargs.items():
+                setattr(self, key, value)
+
+        output = {}
+        for inference_class in self.inference_classes:
+            function_keys = set(inspect.signature(inference_class).parameters.keys())
+            function_keys.remove("inference_pipeline")
+
+            if "kwargs" in function_keys:
+                function_keys.remove("kwargs")
+
+            assert function_keys == set(
+                output.keys()
+            ), "Input to inference class, {}, does not have the correct parameters".format(
+                inference_class.__repr__()
+            )
+
+            print(
+                "Running {} with input keys {}".format(
+                    inference_class.__repr__(),
+                    inspect.signature(inference_class).parameters.keys(),
+                )
+            )
+
+            if inference_pipeline:
+                output = inference_class(
+                    inference_pipeline=inference_pipeline, **output
+                )
+            else:
+                output = inference_class(inference_pipeline=self, **output)
+
+            # if not the last inference class, check that the output keys are correct
+            if inference_class != self.inference_classes[-1]:
+                print(
+                    "Finished {} with output keys {}\n".format(
+                        inference_class.__repr__(), output.keys()
+                    )
+                )
+
+        print("Inference pipeline finished.\n")
+
+        return output
+
+
+if __name__ == "__main__":
+    """Example usage of InferencePipeline."""
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dicom_dir", type=str, required=True)
+    args = parser.parse_args()
+
+    output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../outputs")
+    if not os.path.exists(output_dir):
+        os.mkdir(output_dir)
+    output_file_path = os.path.join(output_dir, "test.nii.gz")
+
+    pipeline = InferencePipeline(
+        [DicomLoader(args.dicom_dir), NiftiSaver()],
+        config={"output_dir": output_file_path},
+    )
+    pipeline()
+
+    print("Done.")
Comp2Comp-main/comp2comp/io/io.py ADDED
@@ -0,0 +1,138 @@
+"""
+@author: louisblankemeier
+"""
+import os
+import shutil
+from pathlib import Path
+from typing import Dict, Union
+
+# import dicom2nifti
+import dosma as dm
+import pydicom
+import SimpleITK as sitk
+
+from comp2comp.inference_class_base import InferenceClass
+
+
+class DicomLoader(InferenceClass):
+    """Load a single dicom series."""
+
+    def __init__(self, input_path: Union[str, Path]):
+        super().__init__()
+        self.dicom_dir = Path(input_path)
+        self.dr = dm.DicomReader()
+
+    def __call__(self, inference_pipeline) -> Dict:
+        medical_volume = self.dr.load(
+            self.dicom_dir, group_by=None, sort_by="InstanceNumber"
+        )[0]
+        return {"medical_volume": medical_volume}
+
+
+class NiftiSaver(InferenceClass):
+    """Save dosma medical volume object to NIfTI file."""
+
+    def __init__(self):
+        super().__init__()
+        # self.output_dir = Path(output_path)
+        self.nw = dm.NiftiWriter()
+
+    def __call__(
+        self, inference_pipeline, medical_volume: dm.MedicalVolume
+    ) -> Dict[str, Path]:
+        nifti_file = inference_pipeline.output_dir
+        self.nw.write(medical_volume, nifti_file)
+        return {"nifti_file": nifti_file}
+
+
+class DicomFinder(InferenceClass):
+    """Find dicom files in a directory."""
+
+    def __init__(self, input_path: Union[str, Path]):
+        super().__init__()
+        self.input_path = Path(input_path)
+
+    def __call__(self, inference_pipeline) -> Dict[str, Path]:
+        """Find dicom files in a directory.
+
+        Args:
+            inference_pipeline (InferencePipeline): Inference pipeline.
+
+        Returns:
+            Dict[str, Path]: Dictionary containing dicom files.
+        """
+        dicom_files = []
+        for file in self.input_path.glob("**/*.dcm"):
+            dicom_files.append(file)
+        inference_pipeline.dicom_file_paths = dicom_files
+        return {}
+
+
+class DicomToNifti(InferenceClass):
+    """Convert dicom files to NIfTI files."""
+
+    def __init__(self, input_path: Union[str, Path], save=True):
+        super().__init__()
+        self.input_path = Path(input_path)
+        self.save = save
+
+    def __call__(self, inference_pipeline):
+        if os.path.exists(
+            os.path.join(
+                inference_pipeline.output_dir, "segmentations", "converted_dcm.nii.gz"
+            )
+        ):
+            return {}
+        if hasattr(inference_pipeline, "medical_volume"):
+            return {}
+        output_dir = inference_pipeline.output_dir
+        segmentations_output_dir = os.path.join(output_dir, "segmentations")
+        os.makedirs(segmentations_output_dir, exist_ok=True)
+
+        # if self.input_path is a folder
+        if self.input_path.is_dir():
+            ds = dicom_series_to_nifti(
+                self.input_path,
+                output_file=os.path.join(
+                    segmentations_output_dir, "converted_dcm.nii.gz"
+                ),
+                reorient_nifti=False,
+            )
+            inference_pipeline.dicom_series_path = str(self.input_path)
+            inference_pipeline.dicom_ds = ds
+        elif str(self.input_path).endswith(".nii"):
+            shutil.copy(
+                self.input_path,
+                os.path.join(segmentations_output_dir, "converted_dcm.nii"),
+            )
+        elif str(self.input_path).endswith(".nii.gz"):
+            shutil.copy(
+                self.input_path,
+                os.path.join(segmentations_output_dir, "converted_dcm.nii.gz"),
+            )
+
+        return {}
+
+
+def series_selector(dicom_path):
+    ds = pydicom.filereader.dcmread(dicom_path)
+    image_type_list = list(ds.ImageType)
+    if not any("primary" in s.lower() for s in image_type_list):
+        raise ValueError("Not primary image type")
+    if not any("original" in s.lower() for s in image_type_list):
+        raise ValueError("Not original image type")
+    # if any("gsi" in s.lower() for s in image_type_list):
+    #     raise ValueError("GSI image type")
+    if ds.ImageOrientationPatient != [1, 0, 0, 0, 1, 0]:
+        raise ValueError("Image orientation is not axial")
+    return ds
+
+
+def dicom_series_to_nifti(input_path, output_file, reorient_nifti):
+    reader = sitk.ImageSeriesReader()
+    dicom_names = reader.GetGDCMSeriesFileNames(str(input_path))
+    ds = series_selector(dicom_names[0])
+    reader.SetFileNames(dicom_names)
+    image = reader.Execute()
+    sitk.WriteImage(image, output_file)
+    return ds
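
Usage sketch (hypothetical path) for the gatekeeping in `series_selector`: it raises `ValueError` for derived/secondary series and non-axial orientations, so callers can filter inputs before conversion:

from comp2comp.io.io import series_selector

try:
    ds = series_selector("/path/to/series/slice_0001.dcm")
    print(ds.ImageType)
except ValueError as e:
    print(f"Skipping series: {e}")
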
Comp2Comp-main/comp2comp/io/io_utils.py ADDED
@@ -0,0 +1,77 @@
+"""
+@author: louisblankemeier
+"""
+import csv
+import os
+
+import pydicom
+
+
+def find_dicom_files(input_path):
+    dicom_series = []
+    if not os.path.isdir(input_path):
+        dicom_series = [str(os.path.abspath(input_path))]
+    else:
+        for root, _, files in os.walk(input_path):
+            for file in files:
+                if file.endswith(".dcm") or file.endswith(".dicom"):
+                    dicom_series.append(os.path.join(root, file))
+    return dicom_series
+
+
+def get_dicom_paths_and_num(path):
+    """
+    Get all paths under a path that contain only dicom files.
+    Args:
+        path (str): Path to search.
+    Returns:
+        list: List of paths.
+    """
+    dicom_paths = []
+    for root, _, files in os.walk(path):
+        if len(files) > 0:
+            if all(file.endswith(".dcm") or file.endswith(".dicom") for file in files):
+                dicom_paths.append((root, len(files)))
+
+    if len(dicom_paths) == 0:
+        raise ValueError("No scans were found in:\n" + path)
+
+    return dicom_paths
+
+
+def get_dicom_or_nifti_paths_and_num(path):
+    """Get all paths under a path that contain only dicom files or a nifti file.
+    Args:
+        path (str): Path to search.
+
+    Returns:
+        list: List of paths.
+    """
+    if path.endswith(".nii") or path.endswith(".nii.gz"):
+        return [(path, 1)]
+    dicom_nifti_paths = []
+    for root, dirs, files in os.walk(path):
+        if len(files) > 0:
+            # if all(file.endswith(".dcm") or file.endswith(".dicom") for file in files):
+            dicom_nifti_paths.append((root, len(files)))
+            # else:
+            #     for file in files:
+            #         if file.endswith(".nii") or file.endswith(".nii.gz"):
+            #             num_slices = 450
+            #             dicom_nifti_paths.append((os.path.join(root, file), num_slices))
+
+    return dicom_nifti_paths
+
+
+def write_dicom_metadata_to_csv(ds, csv_filename):
+    with open(csv_filename, "w", newline="") as csvfile:
+        csvwriter = csv.writer(csvfile)
+        csvwriter.writerow(["Tag", "Keyword", "Value"])
+
+        for element in ds:
+            tag = element.tag
+            keyword = pydicom.datadict.keyword_for_tag(tag)
+            if keyword == "PixelData":
+                continue
+            value = str(element.value)
+            csvwriter.writerow([tag, keyword, value])
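
Example (hypothetical directory) of how batch processing enumerates inputs with `get_dicom_or_nifti_paths_and_num`; each entry pairs a path with the number of files found under it:

from comp2comp.io.io_utils import get_dicom_or_nifti_paths_and_num

for path, num_files in get_dicom_or_nifti_paths_and_num("/data/ct_scans"):
    print(path, num_files)
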
Comp2Comp-main/comp2comp/liver_spleen_pancreas/liver_spleen_pancreas.py ADDED
@@ -0,0 +1,95 @@
+import os
+from pathlib import Path
+from time import time
+from typing import Union
+
+from totalsegmentator.libs import (
+    download_pretrained_weights,
+    nostdout,
+    setup_nnunet,
+)
+
+from comp2comp.inference_class_base import InferenceClass
+
+
+class LiverSpleenPancreasSegmentation(InferenceClass):
+    """Organ segmentation."""
+
+    def __init__(self):
+        super().__init__()
+        # self.input_path = input_path
+
+    def __call__(self, inference_pipeline):
+        # inference_pipeline.dicom_series_path = self.input_path
+        self.output_dir = inference_pipeline.output_dir
+        self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
+        if not os.path.exists(self.output_dir_segmentations):
+            os.makedirs(self.output_dir_segmentations)
+
+        self.model_dir = inference_pipeline.model_dir
+
+        seg, mv = self.organ_seg(
+            os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
+            self.output_dir_segmentations + "organs.nii.gz",
+            inference_pipeline.model_dir,
+        )
+
+        inference_pipeline.segmentation = seg
+        inference_pipeline.medical_volume = mv
+
+        return {}
+
+    def organ_seg(
+        self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir
+    ):
+        """Run organ segmentation.
+
+        Args:
+            input_path (Union[str, Path]): Input path.
+            output_path (Union[str, Path]): Output path.
+        """
+
+        print("Segmenting organs...")
+        st = time()
+        os.environ["SCRATCH"] = self.model_dir
+
+        # Setup nnunet
+        model = "3d_fullres"
+        folds = [0]
+        trainer = "nnUNetTrainerV2_ep4000_nomirror"
+        crop_path = None
+        task_id = [251]
+
+        setup_nnunet()
+        download_pretrained_weights(task_id[0])
+
+        from totalsegmentator.nnunet import nnUNet_predict_image
+
+        with nostdout():
+            seg, mvs = nnUNet_predict_image(
+                input_path,
+                output_path,
+                task_id,
+                model=model,
+                folds=folds,
+                trainer=trainer,
+                tta=False,
+                multilabel_image=True,
+                resample=1.5,
+                crop=None,
+                crop_path=crop_path,
+                task_name="total",
+                nora_tag="None",
+                preview=False,
+                nr_threads_resampling=1,
+                nr_threads_saving=6,
+                quiet=False,
+                verbose=True,
+                test=0,
+            )
+        end = time()
+
+        # Log total time for organ segmentation
+        print(f"Total time for organ segmentation: {end-st:.2f}s.")
+
+        return seg, mvs
Comp2Comp-main/comp2comp/liver_spleen_pancreas/liver_spleen_pancreas_visualization.py ADDED
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+
+import numpy as np
+
+from comp2comp.inference_class_base import InferenceClass
+from comp2comp.liver_spleen_pancreas.visualization_utils import (
+    generate_liver_spleen_pancreas_report,
+    generate_slice_images,
+)
+
+
+class LiverSpleenPancreasVisualizer(InferenceClass):
+    def __init__(self):
+        super().__init__()
+
+        self.unit_dict = {
+            "Volume": r"$\mathregular{cm^3}$",
+            "Mean": "HU",
+            "Median": "HU",
+        }
+
+        self.class_nums = [1, 5, 10]
+        self.organ_names = ["liver", "spleen", "pancreas"]
+
+    def __call__(self, inference_pipeline):
+        self.output_dir = inference_pipeline.output_dir
+        self.output_dir_images_organs = os.path.join(self.output_dir, "images/")
+        inference_pipeline.output_dir_images_organs = self.output_dir_images_organs
+
+        if not os.path.exists(self.output_dir_images_organs):
+            os.makedirs(self.output_dir_images_organs)
+
+        inference_pipeline.medical_volume_arr = np.flip(
+            inference_pipeline.medical_volume.get_fdata(), axis=1
+        )
+        inference_pipeline.segmentation_arr = np.flip(
+            inference_pipeline.segmentation.get_fdata(), axis=1
+        )
+
+        inference_pipeline.pix_dims = inference_pipeline.medical_volume.header[
+            "pixdim"
+        ][1:4]
+        inference_pipeline.vol_per_pixel = np.prod(
+            inference_pipeline.pix_dims / 10
+        )  # mm to cm for having ml/pixel.
+
+        self.organ_metrics = generate_slice_images(
+            inference_pipeline.medical_volume_arr,
+            inference_pipeline.segmentation_arr,
+            self.class_nums,
+            self.unit_dict,
+            inference_pipeline.vol_per_pixel,
+            inference_pipeline.pix_dims,
+            self.output_dir_images_organs,
+            fontsize=24,
+        )
+
+        inference_pipeline.organ_metrics = self.organ_metrics
+
+        generate_liver_spleen_pancreas_report(
+            self.output_dir_images_organs, self.organ_names
+        )
+
+        return {}
+
+
+class LiverSpleenPancreasMetricsPrinter(InferenceClass):
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, inference_pipeline):
+        results = inference_pipeline.organ_metrics
+        organs = list(results.keys())
+
+        name_dist = max([len(o) for o in organs])
+        metrics = []
+        for k in results[list(results.keys())[0]].keys():
+            if k != "Organ":
+                metrics.append(k)
+
+        units = ["cm^3", "HU", "HU"]
+
+        header = (
+            "{:<" + str(name_dist + 4) + "}" + ("{:<" + str(15) + "}") * len(metrics)
+        )
+        header = header.format(
+            "Organ", *[m + "(" + u + ")" for m, u in zip(metrics, units)]
+        )
+
+        base_print = (
+            "{:<" + str(name_dist + 4) + "}" + ("{:<" + str(15) + ".0f}") * len(metrics)
+        )
+
+        print("\n")
+        print(header)
+
+        for organ in results.values():
+            line = base_print.format(*organ.values())
+            print(line)
+
+        print("\n")
+
+        output_dir = inference_pipeline.output_dir
+        self.output_dir_metrics_organs = os.path.join(output_dir, "metrics/")
+
+        if not os.path.exists(self.output_dir_metrics_organs):
+            os.makedirs(self.output_dir_metrics_organs)
+
+        header = (
+            ",".join(["Organ"] + [m + "(" + u + ")" for m, u in zip(metrics, units)])
+            + "\n"
+        )
+        with open(
+            os.path.join(
+                self.output_dir_metrics_organs, "liver_spleen_pancreas_metrics.csv"
+            ),
+            "w",
+        ) as f:
+            f.write(header)
+
+            for organ in results.values():
+                line = ",".join([str(v) for v in organ.values()]) + "\n"
+                f.write(line)
+
+        return {}
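
Worked example of the ml-per-voxel factor computed above: the voxel edge lengths in mm are divided by 10 to get cm, so their product is the voxel volume in cm^3 (i.e. ml); the spacing values here are hypothetical:

import numpy as np

pix_dims = np.array([0.8, 0.8, 1.5])  # mm
vol_per_pixel = np.prod(pix_dims / 10)
print(vol_per_pixel)  # 0.00096 ml per voxel
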
Comp2Comp-main/comp2comp/liver_spleen_pancreas/visualization_utils.py ADDED
@@ -0,0 +1,332 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import scipy
+from matplotlib.colors import ListedColormap
+from PIL import Image
+
+
+def extract_axial_mid_slice(ct, mask, crop=True):
+    slice_idx = np.argmax(mask.sum(axis=(0, 1)))
+
+    ct_slice_z = np.transpose(ct[:, :, slice_idx], axes=(1, 0))
+    mask_slice_z = np.transpose(mask[:, :, slice_idx], axes=(1, 0))
+
+    ct_slice_z = np.flip(ct_slice_z, axis=(0, 1))
+    mask_slice_z = np.flip(mask_slice_z, axis=(0, 1))
+
+    if crop:
+        ct_range_x = np.where(ct_slice_z.max(axis=0) > -200)[0][[0, -1]]
+
+        ct_slice_z = ct_slice_z[
+            ct_range_x[0] : ct_range_x[1], ct_range_x[0] : ct_range_x[1]
+        ]
+        mask_slice_z = mask_slice_z[
+            ct_range_x[0] : ct_range_x[1], ct_range_x[0] : ct_range_x[1]
+        ]
+
+    return ct_slice_z, mask_slice_z
+
+
+def extract_coronal_mid_slice(ct, mask, crop=True):
+    # find the slice with max coherent extent of the organ
+    coronary_extent = np.where(mask.sum(axis=(0, 2)))[0]
+
+    max_extent = 0
+    max_extent_idx = 0
+
+    for idx in coronary_extent:
+        label, num_features = scipy.ndimage.label(mask[:, idx, :])
+
+        if num_features > 1:
+            continue
+        else:
+            extent = len(np.where(label.sum(axis=1))[0])
+            if extent > max_extent:
+                max_extent = extent
+                max_extent_idx = idx
+
+    ct_slice_y = np.transpose(ct[:, max_extent_idx, :], axes=(1, 0))
+    mask_slice_y = np.transpose(mask[:, max_extent_idx, :], axes=(1, 0))
+
+    ct_slice_y = np.flip(ct_slice_y, axis=1)
+    mask_slice_y = np.flip(mask_slice_y, axis=1)
+
+    return ct_slice_y, mask_slice_y
+
+
+def save_slice(
+    ct_slice,
+    mask_slice,
+    path,
+    figsize=(12, 12),
+    corner_text=None,
+    unit_dict=None,
+    aspect=1,
+    show=False,
+    xy_placement=None,
+    class_color=1,
+    fontsize=14,
+):
+    # colormap for shown segmentations
+    color_array = plt.get_cmap("tab10")(range(10))
+    color_array = np.concatenate((np.array([[0, 0, 0, 0]]), color_array[:, :]), axis=0)
+    map_object_seg = ListedColormap(name="segmentation_cmap", colors=color_array)
+
+    fig, axx = plt.subplots(1, figsize=figsize, frameon=False)
+    axx.imshow(
+        ct_slice,
+        cmap="gray",
+        vmin=-400,
+        vmax=400,
+        interpolation="spline36",
+        aspect=aspect,
+        origin="lower",
+    )
+    axx.imshow(
+        mask_slice * class_color,
+        cmap=map_object_seg,
+        vmin=0,
+        vmax=9,
+        alpha=0.2,
+        interpolation="nearest",
+        aspect=aspect,
+        origin="lower",
+    )
+
+    plt.axis("off")
+    axx.axes.get_xaxis().set_visible(False)
+    axx.axes.get_yaxis().set_visible(False)
+
+    y_size, x_size = ct_slice.shape
+
+    if corner_text is not None:
+        bbox_props = dict(boxstyle="round", facecolor="gray", alpha=0.5)
+
+        texts = []
+        for k, v in corner_text.items():
+            if isinstance(v, str):
+                texts.append("{:<9}{}".format(k + ":", v))
+            else:
+                unit = unit_dict[k] if k in unit_dict else ""
+                texts.append("{:<9}{:.0f} {}".format(k + ":", v, unit))
+
+        if xy_placement is None:
+            # get the extent of textbox, remove, and the plot again with correct position
+            t = axx.text(
+                0.5,
+                0.5,
+                "\n".join(texts),
+                color="white",
+                transform=axx.transAxes,
+                fontsize=fontsize,
+                family="monospace",
+                bbox=bbox_props,
+                va="top",
+                ha="left",
+            )
+            xmin, xmax = t.get_window_extent().xmin, t.get_window_extent().xmax
+            xmin, xmax = axx.transAxes.inverted().transform((xmin, xmax))
+
+            xy_placement = [1 - (xmax - xmin) - (xmax - xmin) * 0.09, 0.975]
+            t.remove()
+
+        axx.text(
+            xy_placement[0],
+            xy_placement[1],
+            "\n".join(texts),
+            color="white",
+            transform=axx.transAxes,
+            fontsize=fontsize,
+            family="monospace",
+            bbox=bbox_props,
+            va="top",
+            ha="left",
+        )
+
+    if show:
+        plt.show()
+    else:
+        fig.savefig(path, bbox_inches="tight", pad_inches=0)
+        plt.close(fig)
+
+
+def slicedDilationOrErosion(input_mask, num_iteration, operation):
+    """
+    Perform the dilation or erosion on the smallest subvolume that fits the
+    segmentation, rather than on the full volume.
+    """
+    margin = 2 if num_iteration is None else num_iteration + 1
+
+    # find the minimum volume enclosing the organ
+    x_idx = np.where(input_mask.sum(axis=(1, 2)))[0]
+    x_start, x_end = x_idx[0] - margin, x_idx[-1] + margin
+    y_idx = np.where(input_mask.sum(axis=(0, 2)))[0]
+    y_start, y_end = y_idx[0] - margin, y_idx[-1] + margin
+    z_idx = np.where(input_mask.sum(axis=(0, 1)))[0]
+    z_start, z_end = z_idx[0] - margin, z_idx[-1] + margin
+
+    struct = scipy.ndimage.generate_binary_structure(3, 1)
+    struct = scipy.ndimage.iterate_structure(struct, num_iteration)
+
+    if operation == "dilate":
+        mask_slice = scipy.ndimage.binary_dilation(
+            input_mask[x_start:x_end, y_start:y_end, z_start:z_end], structure=struct
+        ).astype(np.int8)
+    elif operation == "erode":
+        mask_slice = scipy.ndimage.binary_erosion(
+            input_mask[x_start:x_end, y_start:y_end, z_start:z_end], structure=struct
+        ).astype(np.int8)
+
+    output_mask = input_mask.copy()
+
+    output_mask[x_start:x_end, y_start:y_end, z_start:z_end] = mask_slice
+
+    return output_mask
+
+
+def extract_organ_metrics(
+    ct, all_masks, class_num=None, vol_per_pixel=None, erode_mask=True
+):
+    if erode_mask:
+        eroded_mask = slicedDilationOrErosion(
+            input_mask=(all_masks == class_num), num_iteration=3, operation="erode"
+        )
+        ct_organ_vals = ct[eroded_mask == 1]
+    else:
+        ct_organ_vals = ct[all_masks == class_num]
+
+    results = {}
+
+    # in ml
+    organ_vol = (all_masks == class_num).sum() * vol_per_pixel
+    organ_mean = ct_organ_vals.mean()
+    organ_median = np.median(ct_organ_vals)
+
+    results = {
+        "Organ": class_map_part_organs[class_num],
+        "Volume": organ_vol,
+        "Mean": organ_mean,
+        "Median": organ_median,
+    }
+
+    return results
+
+
+def generate_slice_images(
+    ct,
+    all_masks,
+    class_nums,
+    unit_dict,
+    vol_per_pixel,
+    pix_dims,
+    root,
+    fontsize=20,
+    show=False,
+):
+    all_results = {}
+
+    colors = [1, 3, 4]
+
+    for i, c_num in enumerate(class_nums):
+        organ_name = class_map_part_organs[c_num]
+
+        axial_path = os.path.join(root, organ_name.lower() + "_axial.png")
+        coronal_path = os.path.join(root, organ_name.lower() + "_coronal.png")
+
+        ct_slice_z, liver_slice_z = extract_axial_mid_slice(ct, all_masks == c_num)
+        results = extract_organ_metrics(
+            ct, all_masks, class_num=c_num, vol_per_pixel=vol_per_pixel
+        )
+
+        save_slice(
+            ct_slice_z,
+            liver_slice_z,
+            axial_path,
+            figsize=(12, 12),
+            corner_text=results,
+            unit_dict=unit_dict,
+            class_color=colors[i],
+            fontsize=fontsize,
+            show=show,
+        )
+
+        ct_slice_y, liver_slice_y = extract_coronal_mid_slice(ct, all_masks == c_num)
+
+        save_slice(
+            ct_slice_y,
+            liver_slice_y,
+            coronal_path,
+            figsize=(12, 12),
+            aspect=pix_dims[2] / pix_dims[1],
+            show=show,
+            class_color=colors[i],
+        )
+
+        all_results[results["Organ"]] = results
+
+    if show:
+        return
+
+    return all_results
+
+
+def generate_liver_spleen_pancreas_report(root, organ_names):
+    axial_imgs = [
+        Image.open(os.path.join(root, organ + "_axial.png")) for organ in organ_names
+    ]
+    coronal_imgs = [
+        Image.open(os.path.join(root, organ + "_coronal.png")) for organ in organ_names
+    ]
+
+    result_width = max(
+        sum([img.size[0] for img in axial_imgs]),
+        sum([img.size[0] for img in coronal_imgs]),
+    )
+    result_height = max(
+        [a.size[1] + c.size[1] for a, c in zip(axial_imgs, coronal_imgs)]
+    )
+
+    result = Image.new("RGB", (result_width, result_height))
+
+    total_width = 0
+
+    for a_img, c_img in zip(axial_imgs, coronal_imgs):
+        a_width, a_height = a_img.size
+        c_width, c_height = c_img.size
+
+        translate = (a_width - c_width) // 2 if a_width > c_width else 0
+
+        result.paste(im=a_img, box=(total_width, 0))
+        result.paste(im=c_img, box=(translate + total_width, a_height))
+
+        total_width += a_width
+
+    result.save(os.path.join(root, "liver_spleen_pancreas_report.png"))
+
+
+# from https://github.com/wasserth/TotalSegmentator/blob/master/totalsegmentator/map_to_binary.py
+
+class_map_part_organs = {
+    1: "Spleen",
+    2: "Right Kidney",
+    3: "Left Kidney",
+    4: "Gallbladder",
+    5: "Liver",
+    6: "Stomach",
+    7: "Aorta",
+    8: "Inferior vena cava",
+    9: "Portal Vein and Splenic Vein",
+    10: "Pancreas",
+    11: "Right Adrenal Gland",
+    12: "Left Adrenal Gland",
+    13: "lung_upper_lobe_left",
+    14: "lung_lower_lobe_left",
+    15: "lung_upper_lobe_right",
+    16: "lung_middle_lobe_right",
+    17: "lung_lower_lobe_right",
+}
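
Toy demonstration (not repo code) of `slicedDilationOrErosion`: the morphology runs only inside a tight bounding box around the organ, which keeps the HU statistics clear of partial-volume voxels at the boundary without processing the whole volume:

import numpy as np

from comp2comp.liver_spleen_pancreas.visualization_utils import (
    slicedDilationOrErosion,
)

mask = np.zeros((40, 40, 40), dtype=np.int8)
mask[10:30, 10:30, 10:30] = 1  # toy cubic "organ"

eroded = slicedDilationOrErosion(mask, num_iteration=3, operation="erode")
print(int(mask.sum()), int(eroded.sum()))  # erosion strictly shrinks the mask
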
Comp2Comp-main/comp2comp/metrics/metrics.py ADDED
@@ -0,0 +1,156 @@
+from abc import ABC, abstractmethod
+from typing import Callable, Sequence, Union
+
+import numpy as np
+
+
+def flatten_non_category_dims(
+    xs: Union[np.ndarray, Sequence[np.ndarray]], category_dim: int = None
+):
+    """Flattens all non-category dimensions into a single dimension.
+
+    Args:
+        xs (ndarrays): Sequence of ndarrays with the same category dimension.
+        category_dim: The dimension/axis corresponding to different categories.
+            i.e. `C`. If `None`, behaves like `np.flatten(x)`.
+
+    Returns:
+        ndarray: Shape (C, -1) if `category_dim` specified else shape (-1,)
+    """
+    single_item = isinstance(xs, np.ndarray)
+    if single_item:
+        xs = [xs]
+
+    if category_dim is not None:
+        dims = (xs[0].shape[category_dim], -1)
+        xs = (np.moveaxis(x, category_dim, 0).reshape(dims) for x in xs)
+    else:
+        xs = (x.flatten() for x in xs)
+
+    if single_item:
+        return list(xs)[0]
+    else:
+        return xs
+
+
+class Metric(Callable, ABC):
+    """Interface for new metrics.
+
+    A metric should be implemented as a callable with explicitly defined
+    arguments. In other words, metrics should not have `**kwargs` or `**args`
+    options in the `__call__` method.
+
+    While not explicitly constrained to the return type, metrics typically
+    return float value(s). The number of values returned corresponds to the
+    number of categories.
+
+    * metrics should have different name() for different functionality.
+    * `category_dim` duck type if metric can process multiple categories at
+      once.
+
+    To compute metrics:
+
+    .. code-block:: python
+
+        metric = Metric()
+        results = metric(...)
+    """
+
+    def __init__(self, units: str = ""):
+        self.units = units
+
+    def name(self):
+        return type(self).__name__
+
+    def display_name(self):
+        """Name to use for pretty printing and display purposes."""
+        name = self.name()
+        return "{} {}".format(name, self.units) if self.units else name
+
+    @abstractmethod
+    def __call__(self, *args, **kwargs):
+        pass
+
+
+class HounsfieldUnits(Metric):
+    FULL_NAME = "Hounsfield Unit"
+
+    def __init__(self, units="hu"):
+        super().__init__(units)
+
+    def __call__(self, mask, x, category_dim: int = None):
+        mask = mask.astype(bool)
+        if category_dim is None:
+            return np.mean(x[mask])
+
+        assert category_dim == -1
+        num_classes = mask.shape[-1]
+
+        return np.array([np.mean(x[mask[..., c]]) for c in range(num_classes)])
+
+    def name(self):
+        return self.FULL_NAME
+
+
+class CrossSectionalArea(Metric):
+    def __call__(self, mask, spacing=None, category_dim: int = None):
+        pixel_area = np.prod(spacing) if spacing else 1
+        mask = mask.astype(bool)
+        mask = flatten_non_category_dims(mask, category_dim)
+
+        return pixel_area * np.count_nonzero(mask, -1) / 100.0
+
+    def name(self):
+        if self.units:
+            return "Cross-sectional Area ({})".format(self.units)
+        else:
+            return "Cross-sectional Area"
+
+
+def manifest_to_map(manifest, model_type):
+    """Converts a manifest to a map of level/file key to formatted metric strings.
+
+    Args:
+        manifest (list): List of per-level metric dictionaries.
+        model_type (Models): Model whose category ordering determines the output order.
+
+    Returns:
+        dict: A dictionary mapping each level/file key to its metric strings.
+    """
+    # TODO: hacky. Update this
+    figure_text_key = {}
+    for manifest_dict in manifest:
+        try:
+            key = manifest_dict["Level"]
+        except BaseException:
+            key = ".".join((manifest_dict["File"].split("/")[-1]).split(".")[:-1])
+        muscle_hu = f"{manifest_dict['Hounsfield Unit (muscle)']:.2f}"
+        muscle_area = f"{manifest_dict['Cross-sectional Area (cm^2) (muscle)']:.2f}"
+        vat_hu = f"{manifest_dict['Hounsfield Unit (vat)']:.2f}"
+        vat_area = f"{manifest_dict['Cross-sectional Area (cm^2) (vat)']:.2f}"
+        sat_hu = f"{manifest_dict['Hounsfield Unit (sat)']:.2f}"
+        sat_area = f"{manifest_dict['Cross-sectional Area (cm^2) (sat)']:.2f}"
+        imat_hu = f"{manifest_dict['Hounsfield Unit (imat)']:.2f}"
+        imat_area = f"{manifest_dict['Cross-sectional Area (cm^2) (imat)']:.2f}"
+        if model_type.model_name == "abCT_v0.0.1":
+            figure_text_key[key] = [
+                muscle_hu,
+                muscle_area,
+                imat_hu,
+                imat_area,
+                vat_hu,
+                vat_area,
+                sat_hu,
+                sat_area,
+            ]
+        else:
+            figure_text_key[key] = [
+                muscle_hu,
+                muscle_area,
+                vat_hu,
+                vat_area,
+                sat_hu,
+                sat_area,
+                imat_hu,
+                imat_area,
+            ]
+    return figure_text_key
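
Worked example (synthetic data) of the two metrics and the `category_dim` contract, where masks are HxWxC with categories along the last axis:

import numpy as np

from comp2comp.metrics.metrics import CrossSectionalArea, HounsfieldUnits

x = np.full((4, 4), 100.0)  # fake HU slice
mask = np.zeros((4, 4, 2), dtype=bool)
mask[:2, :, 0] = True  # category 0 covers 8 pixels
mask[2:, :, 1] = True  # category 1 covers 8 pixels

hu = HounsfieldUnits()(mask, x, category_dim=-1)
csa = CrossSectionalArea("cm^2")(mask, spacing=(1.0, 1.0), category_dim=-1)
print(hu)  # [100. 100.]
print(csa)  # 8 mm^2 per category -> [0.08 0.08] cm^2
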
Comp2Comp-main/comp2comp/models/models.py ADDED
@@ -0,0 +1,157 @@
+import enum
+import os
+from pathlib import Path
+from typing import Dict, Sequence
+
+import wget
+from keras.models import load_model
+
+
+class Models(enum.Enum):
+    ABCT_V_0_0_1 = (
+        1,
+        "abCT_v0.0.1",
+        {"muscle": 0, "imat": 1, "vat": 2, "sat": 3},
+        False,
+        ("soft", "bone", "custom"),
+    )
+
+    STANFORD_V_0_0_1 = (
+        2,
+        "stanford_v0.0.1",
+        # ("background", "muscle", "bone", "vat", "sat", "imat"),
+        # Category name mapped to channel index
+        {"muscle": 1, "vat": 3, "sat": 4, "imat": 5},
+        True,
+        ("soft", "bone", "custom"),
+    )
+
+    STANFORD_V_0_0_2 = (
+        3,
+        "stanford_v0.0.2",
+        {"muscle": 4, "sat": 1, "vat": 2, "imat": 3},
+        True,
+        ("soft", "bone", "custom"),
+    )
+    TS_SPINE_FULL = (
+        4,
+        "ts_spine_full",
+        # Category name mapped to channel index
+        {
+            "L5": 18,
+            "L4": 19,
+            "L3": 20,
+            "L2": 21,
+            "L1": 22,
+            "T12": 23,
+            "T11": 24,
+            "T10": 25,
+            "T9": 26,
+            "T8": 27,
+            "T7": 28,
+            "T6": 29,
+            "T5": 30,
+            "T4": 31,
+            "T3": 32,
+            "T2": 33,
+            "T1": 34,
+            "C7": 35,
+            "C6": 36,
+            "C5": 37,
+            "C4": 38,
+            "C3": 39,
+            "C2": 40,
+            "C1": 41,
+        },
+        False,
+        (),
+    )
+    TS_SPINE = (
+        5,
+        "ts_spine",
+        # Category name mapped to channel index
+        # {"L5": 18, "L4": 19, "L3": 20, "L2": 21, "L1": 22, "T12": 23},
+        {"L5": 27, "L4": 28, "L3": 29, "L2": 30, "L1": 31, "T12": 32},
+        False,
+        (),
+    )
+    STANFORD_SPINE_V_0_0_1 = (
+        6,
+        "stanford_spine_v0.0.1",
+        # Category name mapped to channel index
+        {"L5": 24, "L4": 23, "L3": 22, "L2": 21, "L1": 20, "T12": 19},
+        False,
+        (),
+    )
+    TS_HIP = (
+        7,
+        "ts_hip",
+        # Category name mapped to channel index
+        {"femur_left": 88, "femur_right": 89},
+        False,
+        (),
+    )
+
+    def __new__(
+        cls,
+        value: int,
+        model_name: str,
+        categories: Dict[str, int],
+        use_softmax: bool,
+        windows: Sequence[str],
+    ):
+        obj = object.__new__(cls)
+        obj._value_ = value
+
+        obj.model_name = model_name
+        obj.categories = categories
+        obj.use_softmax = use_softmax
+        obj.windows = windows
+        return obj
+
+    def load_model(self, model_dir):
+        """Load the model from the models directory.
+
+        Args:
+            model_dir (str): Directory that contains, or will receive, the weights.
+
+        Returns:
+            keras.models.Model: Model.
+        """
+        try:
+            filename = Models.find_model_weights(self.model_name, model_dir)
+        except Exception:
+            print("Downloading muscle/fat model from hugging face")
+            Path(model_dir).mkdir(parents=True, exist_ok=True)
+            wget.download(
+                f"https://huggingface.co/stanfordmimi/stanford_abct_v0.0.1/resolve/main/{self.model_name}.h5",
+                out=os.path.join(model_dir, f"{self.model_name}.h5"),
+            )
+            filename = Models.find_model_weights(self.model_name, model_dir)
+            print("")
+
+        print("Loading muscle/fat model from {}".format(filename))
+        return load_model(filename)
+
+    @staticmethod
+    def model_from_name(model_name):
+        """Get the model enum from the model name.
+
+        Args:
+            model_name (str): Model name.
+
+        Returns:
+            Models: Model enum.
+        """
+        for model in Models:
+            if model.model_name == model_name:
+                return model
+        return None
+
+    @staticmethod
+    def find_model_weights(file_name, model_dir):
+        for root, _, files in os.walk(model_dir):
+            for file in files:
+                if file.startswith(file_name):
+                    filename = os.path.join(root, file)
+        return filename
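
Example lookup (no download is triggered): map a model-name string from a config to its enum entry and inspect the category-to-channel mapping:

from comp2comp.models.models import Models

model = Models.model_from_name("ts_spine")
print(model.categories)  # {'L5': 27, 'L4': 28, 'L3': 29, 'L2': 30, 'L1': 31, 'T12': 32}
print(model.use_softmax)  # False
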
Comp2Comp-main/comp2comp/muscle_adipose_tissue/data.py ADDED
@@ -0,0 +1,214 @@
+import math
+from typing import List, Sequence
+
+import keras.utils as k_utils
+import numpy as np
+import pydicom
+from keras.utils.data_utils import OrderedEnqueuer
+from tqdm import tqdm
+
+
+def parse_windows(windows):
+    """Parse windows provided by the user.
+
+    These windows can either be strings corresponding to popular windowing
+    thresholds for CT or tuples of (lower, upper) bounds.
+
+    Args:
+        windows (list): List of strings or tuples.
+
+    Returns:
+        tuple: Tuple of (lower, upper) bounds.
+    """
+    # name -> (window width, window level)
+    windowing = {
+        "soft": (400, 50),
+        "bone": (1800, 400),
+        "liver": (150, 30),
+        "spine": (250, 50),
+        "custom": (500, 50),
+    }
+    vals = []
+    for w in windows:
+        if isinstance(w, Sequence) and len(w) == 2:
+            assert_msg = "Expected tuple of (lower, upper) bound"
+            assert len(w) == 2, assert_msg
+            assert isinstance(w[0], (float, int)), assert_msg
+            assert isinstance(w[1], (float, int)), assert_msg
+            assert w[0] < w[1], assert_msg
+            vals.append(w)
+            continue
+
+        if w not in windowing:
+            raise KeyError("Window {} not found".format(w))
+        window_width = windowing[w][0]
+        window_level = windowing[w][1]
+        upper = window_level + window_width / 2
+        lower = window_level - window_width / 2
+
+        vals.append((lower, upper))
+
+    return tuple(vals)
+
+
+def _window(xs, bounds):
+    """Apply windowing to an array of CT images.
+
+    Args:
+        xs (ndarray): NxHxW
+        bounds (tuple): (lower, upper) bounds
+
+    Returns:
+        ndarray: Windowed images.
+    """
+
+    imgs = []
+    for lb, ub in bounds:
+        imgs.append(np.clip(xs, a_min=lb, a_max=ub))
+
+    if len(imgs) == 1:
+        return imgs[0]
+    elif xs.shape[-1] == 1:
+        return np.concatenate(imgs, axis=-1)
+    else:
+        return np.stack(imgs, axis=-1)
+
+
+class Dataset(k_utils.Sequence):
+    def __init__(self, files: List[str], batch_size: int = 16, windows=None):
+        self._files = files
+        self._batch_size = batch_size
+        self.windows = windows
+
+    def __len__(self):
+        return math.ceil(len(self._files) / self._batch_size)
+
+    def __getitem__(self, idx):
+        files = self._files[idx * self._batch_size : (idx + 1) * self._batch_size]
+        dcms = [pydicom.read_file(f, force=True) for f in files]
+
+        xs = [(x.pixel_array + int(x.RescaleIntercept)).astype("float32") for x in dcms]
+
+        params = [
+            {"spacing": header.PixelSpacing, "image": x} for header, x in zip(dcms, xs)
+        ]
+
+        # Preprocess xs via windowing.
+        xs = np.stack(xs, axis=0)
+        if self.windows:
+            xs = _window(xs, parse_windows(self.windows))
+        else:
+            xs = xs[..., np.newaxis]
+
+        return xs, params
+
+
+def _swap_muscle_imap(xs, ys, muscle_idx: int, imat_idx: int, threshold=-30.0):
+    """
+    If pixel labeled as muscle but has HU < threshold, change label to imat.
+
+    Args:
+        xs (ndarray): NxHxWxC
+        ys (ndarray): NxHxWxC
+        muscle_idx (int): Index of the muscle label.
+        imat_idx (int): Index of the imat label.
+        threshold (float): Threshold for HU value.
+
+    Returns:
+        ndarray: Segmentation mask with swapped labels.
+    """
+    labels = ys.copy()
+
+    muscle_mask = (labels[..., muscle_idx] > 0.5).astype(int)
+    imat_mask = labels[..., imat_idx]
+
+    imat_mask[muscle_mask.astype(bool) & (xs < threshold)] = 1
+    muscle_mask[xs < threshold] = 0
+
+    labels[..., muscle_idx] = muscle_mask
+    labels[..., imat_idx] = imat_mask
+
+    return labels
+
+
+def postprocess(xs: np.ndarray, ys: np.ndarray):
+    """Built-in post-processing.
+
+    TODO: Make this configurable.
+
+    Args:
+        xs (ndarray): NxHxW
+        ys (ndarray): NxHxWxC
+
+    Returns:
+        ndarray: Post-processed labels.
+    """
+
+    # Add another channel full of zeros to ys
+    ys = np.concatenate([ys, np.zeros_like(ys[..., :1])], axis=-1)
+
+    # If muscle hu is < -30, assume it is imat.
+
+    """
+    if "muscle" in categories and "imat" in categories:
+        ys = _swap_muscle_imap(
+            xs,
+            ys,
+            muscle_idx=categories["muscle"],
+            imat_idx=categories["imat"],
+        )
+    """
+
+    return ys
+
+
+def predict(
+    model,
+    dataset: Dataset,
+    batch_size: int = 16,
+    num_workers: int = 1,
+    max_queue_size: int = 10,
+    use_multiprocessing: bool = False,
+):
+    """Predict segmentation masks for a dataset.
+
+    Args:
+        model (keras.Model): Model to use for prediction.
+        dataset (Dataset): Dataset to predict on.
+        batch_size (int): Batch size.
+        num_workers (int): Number of workers.
+        max_queue_size (int): Maximum queue size.
+        use_multiprocessing (bool): Use multiprocessing.
+
+    Returns:
+        List: List of segmentation masks.
+    """
+
+    if num_workers > 0:
+        enqueuer = OrderedEnqueuer(
+            dataset, use_multiprocessing=use_multiprocessing, shuffle=False
+        )
+        enqueuer.start(workers=num_workers, max_queue_size=max_queue_size)
+        output_generator = enqueuer.get()
+    else:
+        output_generator = iter(dataset)
+
+    num_scans = len(dataset)
+    xs = []
+    ys = []
+    params = []
+    for _ in tqdm(range(num_scans)):
+        x, p_dicts = next(output_generator)
+        y = model.predict(x, batch_size=batch_size)
+
+        image = np.stack([out["image"] for out in p_dicts], axis=0)
+        y = postprocess(image, y)
+
+        params.extend(p_dicts)
+        xs.extend([x[i, ...] for i in range(len(x))])
+        ys.extend([y[i, ...] for i in range(len(y))])
+
+    return xs, ys, params
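
Example of the windowing helpers on synthetic data: "soft" corresponds to a window width/level of 400/50, i.e. clipping to the (-150, 250) HU range, and multiple windows become channels:

import numpy as np

from comp2comp.muscle_adipose_tissue.data import _window, parse_windows

bounds = parse_windows(["soft", "bone"])
print(bounds)  # ((-150.0, 250.0), (-500.0, 1300.0))

xs = np.random.randint(-1000, 2000, size=(1, 8, 8)).astype("float32")
windowed = _window(xs, bounds)
print(windowed.shape)  # (1, 8, 8, 2): one channel per window
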
Comp2Comp-main/comp2comp/muscle_adipose_tissue/muscle_adipose_tissue.py ADDED
@@ -0,0 +1,445 @@
+ import os
+ import zipfile
+ from pathlib import Path
+ from time import perf_counter
+ from typing import List, Union
+
+ import cv2
+ import h5py
+ import nibabel as nib
+ import numpy as np
+ import pandas as pd
+ import wget
+ from keras import backend as K
+ from tqdm import tqdm
+
+ from comp2comp.inference_class_base import InferenceClass
+ from comp2comp.metrics.metrics import CrossSectionalArea, HounsfieldUnits
+ from comp2comp.models.models import Models
+ from comp2comp.muscle_adipose_tissue.data import Dataset, predict
+
+
+ class MuscleAdiposeTissueSegmentation(InferenceClass):
+     """Muscle adipose tissue segmentation class."""
+
+     def __init__(self, batch_size: int, model_name: str, model_dir: str = None):
+         super().__init__()
+         self.batch_size = batch_size
+         self.model_name = model_name
+         self.model_type = Models.model_from_name(model_name)
+
+     def forward_pass_2d(self, files):
+         dataset = Dataset(files, windows=self.model_type.windows)
+         num_workers = 1
+
+         print("Computing segmentation masks using {}...".format(self.model_name))
+         start_time = perf_counter()
+         _, preds, results = predict(
+             self.model,
+             dataset,
+             num_workers=num_workers,
+             use_multiprocessing=num_workers > 1,
+             batch_size=self.batch_size,
+         )
+         K.clear_session()
+         print(
+             f"Completed {len(files)} segmentations in {(perf_counter() - start_time):.2f} seconds."
+         )
+         for i in range(len(results)):
+             results[i]["preds"] = preds[i]
+         return results
+
+     def download_muscle_adipose_tissue_model(self, model_dir: Union[str, Path]):
+         download_dir = Path(
+             os.path.join(
+                 model_dir,
+                 ".totalsegmentator/nnunet/results/nnUNet/2d/Task927_FatMuscle/nnUNetTrainerV2__nnUNetPlansv2.1",
+             )
+         )
+         all_path = download_dir / "all"
+         if not os.path.exists(all_path):
+             download_dir.mkdir(parents=True, exist_ok=True)
+             wget.download(
+                 "https://huggingface.co/stanfordmimi/multilevel_muscle_adipose_tissue/resolve/main/all.zip",
+                 out=os.path.join(download_dir, "all.zip"),
+             )
+             with zipfile.ZipFile(os.path.join(download_dir, "all.zip"), "r") as zip_ref:
+                 zip_ref.extractall(download_dir)
+             os.remove(os.path.join(download_dir, "all.zip"))
+             wget.download(
+                 "https://huggingface.co/stanfordmimi/multilevel_muscle_adipose_tissue/resolve/main/plans.pkl",
+                 out=os.path.join(download_dir, "plans.pkl"),
+             )
+             print("Muscle and adipose tissue model downloaded.")
+         else:
+             print("Muscle and adipose tissue model already downloaded.")
+
+     def __call__(self, inference_pipeline):
+         inference_pipeline.muscle_adipose_tissue_model_type = self.model_type
+         inference_pipeline.muscle_adipose_tissue_model_name = self.model_name
+
+         if self.model_name == "stanford_v0.0.2":
+             self.download_muscle_adipose_tissue_model(inference_pipeline.model_dir)
+             nifti_path = os.path.join(
+                 inference_pipeline.output_dir, "segmentations", "converted_dcm.nii.gz"
+             )
+             output_path = os.path.join(
+                 inference_pipeline.output_dir,
+                 "segmentations",
+                 "converted_dcm_seg.nii.gz",
+             )
+
+             from nnunet.inference import predict
+
+             predict.predict_cases(
+                 model=os.path.join(
+                     inference_pipeline.model_dir,
+                     ".totalsegmentator/nnunet/results/nnUNet/2d/Task927_FatMuscle/nnUNetTrainerV2__nnUNetPlansv2.1",
+                 ),
+                 list_of_lists=[[nifti_path]],
+                 output_filenames=[output_path],
+                 folds="all",
+                 save_npz=False,
+                 num_threads_preprocessing=8,
+                 num_threads_nifti_save=8,
+                 segs_from_prev_stage=None,
+                 do_tta=False,
+                 mixed_precision=True,
+                 overwrite_existing=False,
+                 all_in_gpu=False,
+                 step_size=0.5,
+                 checkpoint_name="model_final_checkpoint",
+                 segmentation_export_kwargs=None,
+             )
+
+             image_nib = nib.load(nifti_path)
+             image_nib = nib.as_closest_canonical(image_nib)
+             image = image_nib.get_fdata()
+             pred = nib.load(output_path)
+             pred = nib.as_closest_canonical(pred)
+             pred = pred.get_fdata()
+
+             images = [image[:, :, i] for i in range(image.shape[-1])]
+             preds = [pred[:, :, i] for i in range(pred.shape[-1])]
+
+             # flip both axes and transpose
+             images = [np.flip(np.flip(image, axis=0), axis=1).T for image in images]
+             preds = [np.flip(np.flip(pred, axis=0), axis=1).T for pred in preds]
+
+             spacings = [
+                 image_nib.header.get_zooms()[0:2] for i in range(image.shape[-1])
+             ]
+
+             categories = self.model_type.categories
+
+             # for each image in images, convert to one hot encoding
+             masks = []
+             for pred in preds:
+                 mask = np.zeros((pred.shape[0], pred.shape[1], 4))
+                 for i, category in enumerate(categories):
+                     mask[:, :, i] = pred == categories[category]
+                 mask = mask.astype(np.uint8)
+                 masks.append(mask)
+             return {"images": images, "preds": masks, "spacings": spacings}
+
+         else:
+             dicom_file_paths = inference_pipeline.dicom_file_paths
+             # if dicom_file_names is not an attribute of inference_pipeline, add it
+             if not hasattr(inference_pipeline, "dicom_file_names"):
+                 inference_pipeline.dicom_file_names = [
+                     dicom_file_path.stem for dicom_file_path in dicom_file_paths
+                 ]
+             self.model = self.model_type.load_model(inference_pipeline.model_dir)
+
+             results = self.forward_pass_2d(dicom_file_paths)
+             images = []
+             for result in results:
+                 images.append(result["image"])
+             preds = []
+             for result in results:
+                 preds.append(result["preds"])
+             spacings = []
+             for result in results:
+                 spacings.append(result["spacing"])
+
+             return {"images": images, "preds": preds, "spacings": spacings}
+
+
+ class MuscleAdiposeTissuePostProcessing(InferenceClass):
+     """Post-process muscle and adipose tissue segmentation."""
+
+     def __init__(self):
+         super().__init__()
+
+     def preds_to_mask(self, preds):
+         """Convert model predictions to a mask.
+
+         Args:
+             preds (np.ndarray): Model predictions.
+
+         Returns:
+             np.ndarray: Mask.
+         """
+         if self.use_softmax:
+             # softmax
+             labels = np.zeros_like(preds, dtype=np.uint8)
+             l_argmax = np.argmax(preds, axis=-1)
+             for c in range(labels.shape[-1]):
+                 labels[l_argmax == c, c] = 1
+             return labels.astype(bool)
+         else:
+             # sigmoid
+             return preds >= 0.5
+
+     def __call__(self, inference_pipeline, images, preds, spacings):
+         """Post-process muscle and adipose tissue segmentation."""
+         self.model_type = inference_pipeline.muscle_adipose_tissue_model_type
+         self.use_softmax = self.model_type.use_softmax
+         self.model_name = inference_pipeline.muscle_adipose_tissue_model_name
+         return self.post_process(images, preds, spacings)
+
+     def remove_small_objects(self, mask, min_size=10):
+         mask = mask.astype(np.uint8)
+         components, output, stats, centroids = cv2.connectedComponentsWithStats(
+             mask, connectivity=8
+         )
+         sizes = stats[1:, -1]
+         mask = np.zeros(output.shape)
+         for i in range(0, components - 1):
+             if sizes[i] >= min_size:
+                 mask[output == i + 1] = 1
+         return mask
+
+     def post_process(
+         self,
+         images,
+         preds,
+         spacings,
+     ):
+         categories = self.model_type.categories
+
+         start_time = perf_counter()
+
+         if self.model_name == "stanford_v0.0.2":
+             masks = preds
+         else:
+             masks = [self.preds_to_mask(p) for p in preds]
+
+         for i, _ in enumerate(masks):
+             # Keep only channels from the model_type categories dict
+             masks[i] = np.squeeze(masks[i])
+
+         masks = self.fill_holes(masks)
+
+         cats = list(categories.keys())
+
+         file_idx = 0
+         for mask, image in tqdm(zip(masks, images), total=len(masks)):
+             muscle_mask = mask[..., cats.index("muscle")]
+             imat_mask = mask[..., cats.index("imat")]
+             imat_mask = (
+                 np.logical_and(
+                     (image * muscle_mask) <= -30, (image * muscle_mask) >= -190
+                 )
+             ).astype(int)
+             imat_mask = self.remove_small_objects(imat_mask)
+             mask[..., cats.index("imat")] += imat_mask
+             mask[..., cats.index("muscle")][imat_mask == 1] = 0
+             masks[file_idx] = mask
+             images[file_idx] = image
+             file_idx += 1
+
+         print(
+             f"Completed post-processing in {(perf_counter() - start_time):.2f} seconds."
+         )
+
+         return {"images": images, "masks": masks, "spacings": spacings}
+
+     # function that fills in holes in a segmentation mask
+     def _fill_holes(self, mask: np.ndarray, mask_id: int):
+         """Fill in holes in a segmentation mask.
+
+         Args:
+             mask (ndarray): NxHxW
+             mask_id (int): Label of the mask.
+
+         Returns:
+             ndarray: Filled mask.
+         """
+         int_mask = ((1 - mask) > 0.5).astype(np.int8)
+         components, output, stats, _ = cv2.connectedComponentsWithStats(
+             int_mask, connectivity=8
+         )
+         sizes = stats[1:, -1]
+         components = components - 1
+         # Larger threshold for SAT
+         # TODO make this configurable / parameter
+         if mask_id == 2:
+             min_size = 200
+         else:
+             # min_size = 50  # Smaller threshold for everything else
+             min_size = 20
+         img_out = np.ones_like(mask)
+         for i in range(0, components):
+             if sizes[i] > min_size:
+                 img_out[output == i + 1] = 0
+         return img_out
+
+     def fill_holes(self, ys: List):
+         """Take an array of size NxHxWxC and for each channel fill in holes.
+
+         Args:
+             ys (list): List of segmentation masks.
+         """
+         segs = []
+         for n in range(len(ys)):
+             ys_out = [
+                 self._fill_holes(ys[n][..., i], i) for i in range(ys[n].shape[-1])
+             ]
+             segs.append(np.stack(ys_out, axis=2).astype(float))
+
+         return segs
+
+
+ class MuscleAdiposeTissueComputeMetrics(InferenceClass):
+     """Compute muscle and adipose tissue metrics."""
+
+     def __init__(self):
+         super().__init__()
+
+     def __call__(self, inference_pipeline, images, masks, spacings):
+         """Compute muscle and adipose tissue metrics."""
+         self.model_type = inference_pipeline.muscle_adipose_tissue_model_type
+         self.model_name = inference_pipeline.muscle_adipose_tissue_model_name
+         metrics = self.compute_metrics_all(images, masks, spacings)
+         return metrics
+
+     def compute_metrics_all(self, images, masks, spacings):
+         """Compute metrics for all images and masks.
+
+         Args:
+             images (List[np.ndarray]): Images.
+             masks (List[np.ndarray]): Masks.
+
+         Returns:
+             Dict: Results.
+         """
+         results = []
+         for image, mask, spacing in zip(images, masks, spacings):
+             results.append(self.compute_metrics(image, mask, spacing))
+         return {"images": images, "results": results}
+
+     def compute_metrics(self, x, mask, spacing):
+         """Compute results for a given segmentation."""
+         categories = self.model_type.categories
+
+         hu = HounsfieldUnits()
+         csa_units = "cm^2" if spacing else ""
+         csa = CrossSectionalArea(csa_units)
+
+         hu_vals = hu(mask, x, category_dim=-1)
+         csa_vals = csa(mask=mask, spacing=spacing, category_dim=-1)
+
+         # check if any values are nan and replace with 0
+         hu_vals = np.nan_to_num(hu_vals)
+         csa_vals = np.nan_to_num(csa_vals)
+
+         assert mask.shape[-1] == len(
+             categories
+         ), "{} categories found in mask, but only {} categories specified".format(
+             mask.shape[-1], len(categories)
+         )
+
+         results = {
+             cat: {
+                 "mask": mask[..., idx],
+                 hu.name(): hu_vals[idx],
+                 csa.name(): csa_vals[idx],
+             }
+             for idx, cat in enumerate(categories.keys())
+         }
+         return results
+
+
+ class MuscleAdiposeTissueH5Saver(InferenceClass):
+     """Save results to an HDF5 file."""
+
+     def __init__(self):
+         super().__init__()
+
+     def __call__(self, inference_pipeline, results):
+         """Save results to an HDF5 file."""
+         self.model_type = inference_pipeline.muscle_adipose_tissue_model_type
+         self.model_name = inference_pipeline.muscle_adipose_tissue_model_name
+         self.output_dir = inference_pipeline.output_dir
+         self.h5_output_dir = os.path.join(self.output_dir, "segmentations")
+         os.makedirs(self.h5_output_dir, exist_ok=True)
+         self.dicom_file_paths = inference_pipeline.dicom_file_paths
+         self.dicom_file_names = inference_pipeline.dicom_file_names
+         self.save_results(results)
+         return {"results": results}
+
+     def save_results(self, results):
+         """Save results to an HDF5 file."""
+         categories = self.model_type.categories
+         cats = list(categories.keys())
+
+         for i, result in enumerate(results):
+             file_name = self.dicom_file_names[i]
+             with h5py.File(
+                 os.path.join(self.h5_output_dir, file_name + ".h5"), "w"
+             ) as f:
+                 for cat in cats:
+                     mask = result[cat]["mask"]
+                     f.create_dataset(name=cat, data=np.array(mask, dtype=np.uint8))
+
+
+ class MuscleAdiposeTissueMetricsSaver(InferenceClass):
+     """Save metrics to a CSV file."""
+
+     def __init__(self):
+         super().__init__()
+
+     def __call__(self, inference_pipeline, results):
+         """Save metrics to a CSV file."""
+         self.model_type = inference_pipeline.muscle_adipose_tissue_model_type
+         self.model_name = inference_pipeline.muscle_adipose_tissue_model_name
+         self.output_dir = inference_pipeline.output_dir
+         self.csv_output_dir = os.path.join(self.output_dir, "metrics")
+         os.makedirs(self.csv_output_dir, exist_ok=True)
+         self.dicom_file_paths = inference_pipeline.dicom_file_paths
+         self.dicom_file_names = inference_pipeline.dicom_file_names
+         self.save_results(results)
+         return {}
+
+     def save_results(self, results):
+         """Save results to a CSV file."""
+         df = pd.DataFrame(
+             columns=[
+                 "Level",
+                 "Index",
+                 "Muscle HU",
+                 "Muscle CSA (cm^2)",
+                 "SAT HU",
+                 "SAT CSA (cm^2)",
+                 "VAT HU",
+                 "VAT CSA (cm^2)",
+                 "IMAT HU",
+                 "IMAT CSA (cm^2)",
+             ]
+         )
+
+         for i, result in enumerate(results):
+             row = []
+             row.append(self.dicom_file_names[i])
+             row.append(self.dicom_file_paths[i])
+             for cat in result:
+                 row.append(result[cat]["Hounsfield Unit"])
+                 row.append(result[cat]["Cross-sectional Area (cm^2)"])
+             df.loc[i] = row
+         df = df.iloc[::-1]
+         df.to_csv(
+             os.path.join(self.csv_output_dir, "muscle_adipose_tissue_metrics.csv"),
+             index=False,
+         )
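Note: a standalone sketch of the IMAT relabeling rule in post_process() above: pixels inside the muscle mask whose intensity falls in the adipose window of -190 to -30 HU are moved from the muscle channel to the IMAT channel (the pipeline additionally drops connected components smaller than 10 pixels via remove_small_objects). The function name is illustrative only.

    import numpy as np

    def relabel_imat(image: np.ndarray, muscle: np.ndarray, imat: np.ndarray):
        """image: 2D HU slice; muscle/imat: binary masks of the same shape."""
        hu_in_muscle = image * muscle  # zero outside the muscle mask
        fat_like = np.logical_and(hu_in_muscle <= -30, hu_in_muscle >= -190)
        imat = np.clip(imat + fat_like.astype(int), 0, 1)
        muscle = np.where(fat_like, 0, muscle)
        return muscle, imat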
Comp2Comp-main/comp2comp/muscle_adipose_tissue/muscle_adipose_tissue_visualization.py ADDED
@@ -0,0 +1,181 @@
+ """
+ @author: louisblankemeier
+ """
+
+ import os
+ from pathlib import Path
+
+ import numpy as np
+
+ from comp2comp.inference_class_base import InferenceClass
+ from comp2comp.visualization.detectron_visualizer import Visualizer
+
+
+ class MuscleAdiposeTissueVisualizer(InferenceClass):
+     def __init__(self):
+         super().__init__()
+
+         self._spine_colors = {
+             "L5": [255, 0, 0],
+             "L4": [0, 255, 0],
+             "L3": [255, 255, 0],
+             "L2": [255, 128, 0],
+             "L1": [0, 255, 255],
+             "T12": [255, 0, 255],
+         }
+
+         self._muscle_fat_colors = {
+             "muscle": [255, 136, 133],
+             "imat": [154, 135, 224],
+             "vat": [140, 197, 135],
+             "sat": [246, 190, 129],
+         }
+
+         self._SPINE_TEXT_OFFSET_FROM_TOP = 10.0
+         self._SPINE_TEXT_OFFSET_FROM_RIGHT = 63.0
+         self._SPINE_TEXT_VERTICAL_SPACING = 14.0
+
+         self._MUSCLE_FAT_TEXT_HORIZONTAL_SPACING = 40.0
+         self._MUSCLE_FAT_TEXT_VERTICAL_SPACING = 14.0
+         self._MUSCLE_FAT_TEXT_OFFSET_FROM_TOP = 22.0
+         self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT = 181.0
+
+     def __call__(self, inference_pipeline, images, results):
+         self.output_dir = inference_pipeline.output_dir
+         self.dicom_file_names = inference_pipeline.dicom_file_names
+         # if spine is an attribute of the inference pipeline, use it
+         if not hasattr(inference_pipeline, "spine"):
+             spine = False
+         else:
+             spine = True
+
+         for i, (image, result) in enumerate(zip(images, results)):
+             # now, result is a dict with keys for each tissue
+             dicom_file_name = self.dicom_file_names[i]
+             self.save_binary_segmentation_overlay(image, result, dicom_file_name, spine)
+         # pass along for next class in pipeline
+         return {"results": results}
+
+     def save_binary_segmentation_overlay(self, image, result, dicom_file_name, spine):
+         file_name = dicom_file_name + ".png"
+         img_in = image
+         assert img_in.shape == (512, 512), "Image shape is not 512 x 512"
+
+         img_in = np.clip(img_in, -300, 1800)
+         img_in = self.normalize_img(img_in) * 255.0
+
+         # Create the folder to save the images
+         images_base_path = Path(self.output_dir) / "images"
+         images_base_path.mkdir(exist_ok=True)
+
+         text_start_vertical_offset = self._MUSCLE_FAT_TEXT_OFFSET_FROM_TOP
+
+         img_in = img_in.reshape((img_in.shape[0], img_in.shape[1], 1))
+         img_rgb = np.tile(img_in, (1, 1, 3))
+
+         vis = Visualizer(img_rgb)
+         vis.draw_text(
+             text="Density (HU)",
+             position=(
+                 img_in.shape[1] - self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT - 63,
+                 text_start_vertical_offset,
+             ),
+             color=[1, 1, 1],
+             font_size=9,
+             horizontal_alignment="left",
+         )
+         vis.draw_text(
+             text="Area (CM²)",
+             position=(
+                 img_in.shape[1] - self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT - 63,
+                 text_start_vertical_offset + self._MUSCLE_FAT_TEXT_VERTICAL_SPACING,
+             ),
+             color=[1, 1, 1],
+             font_size=9,
+             horizontal_alignment="left",
+         )
+
+         if spine:
+             spine_color = np.array(self._spine_colors[dicom_file_name]) / 255.0
+             vis.draw_box(
+                 box_coord=(1, 1, img_in.shape[0] - 1, img_in.shape[1] - 1),
+                 alpha=1,
+                 edge_color=spine_color,
+             )
+             # draw the level T12 - L5 in the upper left corner
+             if dicom_file_name == "T12":
+                 position = (40, 15)
+             else:
+                 position = (30, 15)
+             vis.draw_text(
+                 text=dicom_file_name, position=position, color=spine_color, font_size=24
+             )
+
+         for idx, tissue in enumerate(result.keys()):
+             alpha_val = 0.9
+             color = np.array(self._muscle_fat_colors[tissue]) / 255.0
+             edge_color = color
+             mask = result[tissue]["mask"]
+
+             vis.draw_binary_mask(
+                 mask,
+                 color=color,
+                 edge_color=edge_color,
+                 alpha=alpha_val,
+                 area_threshold=0,
+             )
+
+             hu_val = round(result[tissue]["Hounsfield Unit"])
+             area_val = round(result[tissue]["Cross-sectional Area (cm^2)"])
+
+             vis.draw_text(
+                 text=tissue,
+                 position=(
+                     mask.shape[1]
+                     - self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT
+                     + self._MUSCLE_FAT_TEXT_HORIZONTAL_SPACING * (idx + 1),
+                     text_start_vertical_offset - self._MUSCLE_FAT_TEXT_VERTICAL_SPACING,
+                 ),
+                 color=color,
+                 font_size=9,
+                 horizontal_alignment="center",
+             )
+
+             vis.draw_text(
+                 text=hu_val,
+                 position=(
+                     mask.shape[1]
+                     - self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT
+                     + self._MUSCLE_FAT_TEXT_HORIZONTAL_SPACING * (idx + 1),
+                     text_start_vertical_offset,
+                 ),
+                 color=color,
+                 font_size=9,
+                 horizontal_alignment="center",
+             )
+             vis.draw_text(
+                 text=area_val,
+                 position=(
+                     mask.shape[1]
+                     - self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT
+                     + self._MUSCLE_FAT_TEXT_HORIZONTAL_SPACING * (idx + 1),
+                     text_start_vertical_offset + self._MUSCLE_FAT_TEXT_VERTICAL_SPACING,
+                 ),
+                 color=color,
+                 font_size=9,
+                 horizontal_alignment="center",
+             )
+
+         vis_obj = vis.get_output()
+         vis_obj.save(os.path.join(images_base_path, file_name))
+
+     def normalize_img(self, img: np.ndarray) -> np.ndarray:
+         """Normalize the image.
+
+         Args:
+             img (np.ndarray): Input image.
+
+         Returns:
+             np.ndarray: Normalized image.
+         """
+         return (img - img.min()) / (img.max() - img.min())
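Note: a minimal sketch of the display preprocessing used by save_binary_segmentation_overlay() above: clip the CT slice to a wide window (-300 to 1800 HU), min-max normalize to [0, 1], scale to 8 bits, and tile to three channels for RGB drawing. The function name is illustrative only.

    import numpy as np

    def to_display_rgb(img_hu: np.ndarray) -> np.ndarray:
        img = np.clip(img_hu, -300, 1800).astype(float)
        img = (img - img.min()) / (img.max() - img.min()) * 255.0
        return np.tile(img[..., None], (1, 1, 3)).astype(np.uint8)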
Comp2Comp-main/comp2comp/spine/spine.py ADDED
@@ -0,0 +1,483 @@
+ """
+ @author: louisblankemeier
+ """
+
+ import math
+ import os
+ import shutil
+ import zipfile
+ from pathlib import Path
+ from time import time
+ from typing import Union
+
+ import nibabel as nib
+ import numpy as np
+ import pandas as pd
+ import wget
+ from PIL import Image
+ from totalsegmentatorv2.python_api import totalsegmentator
+
+ from comp2comp.inference_class_base import InferenceClass
+ from comp2comp.io import io_utils
+ from comp2comp.models.models import Models
+ from comp2comp.spine import spine_utils
+ from comp2comp.visualization.dicom import to_dicom
+
+ # from totalsegmentator.libs import (
+ #     download_pretrained_weights,
+ #     nostdout,
+ #     setup_nnunet,
+ # )
+
+
+ class SpineSegmentation(InferenceClass):
+     """Spine segmentation."""
+
+     def __init__(self, model_name, save=True):
+         super().__init__()
+         self.model_name = model_name
+         self.save_segmentations = save
+
+     def __call__(self, inference_pipeline):
+         # inference_pipeline.dicom_series_path = self.input_path
+         self.output_dir = inference_pipeline.output_dir
+         self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
+         if not os.path.exists(self.output_dir_segmentations):
+             os.makedirs(self.output_dir_segmentations)
+
+         self.model_dir = inference_pipeline.model_dir
+
+         # seg, mv = self.spine_seg(
+         #     os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
+         #     self.output_dir_segmentations + "spine.nii.gz",
+         #     inference_pipeline.model_dir,
+         # )
+         os.environ["TOTALSEG_WEIGHTS_PATH"] = self.model_dir
+
+         seg = totalsegmentator(
+             input=os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
+             output=os.path.join(self.output_dir_segmentations, "segmentation.nii"),
+             task_ids=[292],
+             ml=True,
+             nr_thr_resamp=1,
+             nr_thr_saving=6,
+             fast=False,
+             nora_tag="None",
+             preview=False,
+             task="total",
+             # roi_subset=[
+             #     "vertebrae_T12",
+             #     "vertebrae_L1",
+             #     "vertebrae_L2",
+             #     "vertebrae_L3",
+             #     "vertebrae_L4",
+             #     "vertebrae_L5",
+             # ],
+             roi_subset=None,
+             statistics=False,
+             radiomics=False,
+             crop_path=None,
+             body_seg=False,
+             force_split=False,
+             output_type="nifti",
+             quiet=False,
+             verbose=False,
+             test=0,
+             skip_saving=True,
+             device="gpu",
+             license_number=None,
+             statistics_exclude_masks_at_border=True,
+             no_derived_masks=False,
+             v1_order=False,
+         )
+         mv = nib.load(
+             os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz")
+         )
+
+         # inference_pipeline.segmentation = nib.load(
+         #     os.path.join(self.output_dir_segmentations, "segmentation.nii")
+         # )
+         inference_pipeline.segmentation = seg
+         inference_pipeline.medical_volume = mv
+         inference_pipeline.save_segmentations = self.save_segmentations
+         return {}
+
+     def setup_nnunet_c2c(self, model_dir: Union[str, Path]):
+         """Adapted from TotalSegmentator."""
+
+         model_dir = Path(model_dir)
+         config_dir = model_dir / Path("." + self.model_name)
+         (config_dir / "nnunet/results/nnUNet/3d_fullres").mkdir(
+             exist_ok=True, parents=True
+         )
+         (config_dir / "nnunet/results/nnUNet/2d").mkdir(exist_ok=True, parents=True)
+         weights_dir = config_dir / "nnunet/results"
+         self.weights_dir = weights_dir
+
+         os.environ["nnUNet_raw_data_base"] = str(
+             weights_dir
+         )  # not needed, just needs to be an existing directory
+         os.environ["nnUNet_preprocessed"] = str(
+             weights_dir
+         )  # not needed, just needs to be an existing directory
+         os.environ["RESULTS_FOLDER"] = str(weights_dir)
+
+     def download_spine_model(self, model_dir: Union[str, Path]):
+         download_dir = Path(
+             os.path.join(
+                 self.weights_dir,
+                 "nnUNet/3d_fullres/Task252_Spine/nnUNetTrainerV2_ep4000_nomirror__nnUNetPlansv2.1",
+             )
+         )
+         fold_0_path = download_dir / "fold_0"
+         if not os.path.exists(fold_0_path):
+             download_dir.mkdir(parents=True, exist_ok=True)
+             wget.download(
+                 "https://huggingface.co/louisblankemeier/spine_v1/resolve/main/fold_0.zip",
+                 out=os.path.join(download_dir, "fold_0.zip"),
+             )
+             with zipfile.ZipFile(
+                 os.path.join(download_dir, "fold_0.zip"), "r"
+             ) as zip_ref:
+                 zip_ref.extractall(download_dir)
+             os.remove(os.path.join(download_dir, "fold_0.zip"))
+             wget.download(
+                 "https://huggingface.co/louisblankemeier/spine_v1/resolve/main/plans.pkl",
+                 out=os.path.join(download_dir, "plans.pkl"),
+             )
+             print("Spine model downloaded.")
+         else:
+             print("Spine model already downloaded.")
+
+     def spine_seg(
+         self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir
+     ):
+         """Run spine segmentation.
+
+         Note: this path depends on setup_nnunet, download_pretrained_weights, and
+         nostdout from totalsegmentator.libs (imported at the top of this file,
+         currently commented out).
+
+         Args:
+             input_path (Union[str, Path]): Input path.
+             output_path (Union[str, Path]): Output path.
+         """
+
+         print("Segmenting spine...")
+         st = time()
+         os.environ["SCRATCH"] = self.model_dir
+         os.environ["TOTALSEG_WEIGHTS_PATH"] = self.model_dir
+
+         # Setup nnunet
+         model = "3d_fullres"
+         folds = [0]
+         trainer = "nnUNetTrainerV2_ep4000_nomirror"
+         crop_path = None
+         task_id = [252]
+
+         if self.model_name == "ts_spine":
+             setup_nnunet()
+             download_pretrained_weights(task_id[0])
+         elif self.model_name == "stanford_spine_v0.0.1":
+             self.setup_nnunet_c2c(model_dir)
+             self.download_spine_model(model_dir)
+         else:
+             raise ValueError("Invalid model name.")
+
+         if not self.save_segmentations:
+             output_path = None
+
+         from totalsegmentator.nnunet import nnUNet_predict_image
+
+         with nostdout():
+             img, seg = nnUNet_predict_image(
+                 input_path,
+                 output_path,
+                 task_id,
+                 model=model,
+                 folds=folds,
+                 trainer=trainer,
+                 tta=False,
+                 multilabel_image=True,
+                 resample=1.5,
+                 crop=None,
+                 crop_path=crop_path,
+                 task_name="total",
+                 nora_tag="None",
+                 preview=False,
+                 nr_threads_resampling=1,
+                 nr_threads_saving=6,
+                 quiet=False,
+                 verbose=False,
+                 test=0,
+             )
+         end = time()
+
+         # Log total time for spine segmentation
+         print(f"Total time for spine segmentation: {end - st:.2f}s.")
+
+         if self.model_name == "stanford_spine_v0.0.1":
+             seg_data = seg.get_fdata()
+             # subtract 17 from seg values except for 0
+             seg_data = np.where(seg_data == 0, 0, seg_data - 17)
+             seg = nib.Nifti1Image(seg_data, seg.affine, seg.header)
+
+         return seg, img
+
+
+ class AxialCropper(InferenceClass):
+     """Crop the CT image (medical_volume) and segmentation based on user-specified
+     lower and upper levels of the spine.
+     """
+
+     def __init__(self, lower_level: str = "L5", upper_level: str = "L1", save=True):
+         """
+         Args:
+             lower_level (str, optional): Lower level of the spine. Defaults to "L5".
+             upper_level (str, optional): Upper level of the spine. Defaults to "L1".
+             save (bool, optional): Save cropped image and segmentation. Defaults to True.
+
+         Raises:
+             ValueError: If lower_level or upper_level is not a valid spine level.
+         """
+         super().__init__()
+         self.lower_level = lower_level
+         self.upper_level = upper_level
+         ts_spine_full_model = Models.model_from_name("ts_spine_full")
+         categories = ts_spine_full_model.categories
+         try:
+             self.lower_level_index = categories[self.lower_level]
+             self.upper_level_index = categories[self.upper_level]
+         except KeyError:
+             raise ValueError("Invalid spine level.") from None
+         self.save = save
+
+     def __call__(self, inference_pipeline):
+         """
+         First dim goes from L to R.
+         Second dim goes from P to A.
+         Third dim goes from I to S.
+         """
+         segmentation = inference_pipeline.segmentation
+         segmentation_data = segmentation.get_fdata()
+         upper_level_index = np.where(segmentation_data == self.upper_level_index)[
+             2
+         ].max()
+         lower_level_index = np.where(segmentation_data == self.lower_level_index)[
+             2
+         ].min()
+         segmentation = segmentation.slicer[:, :, lower_level_index:upper_level_index]
+         inference_pipeline.segmentation = segmentation
+
+         medical_volume = inference_pipeline.medical_volume
+         medical_volume = medical_volume.slicer[
+             :, :, lower_level_index:upper_level_index
+         ]
+         inference_pipeline.medical_volume = medical_volume
+
+         if self.save:
+             nib.save(
+                 segmentation,
+                 os.path.join(
+                     inference_pipeline.output_dir, "segmentations", "spine.nii.gz"
+                 ),
+             )
+             nib.save(
+                 medical_volume,
+                 os.path.join(
+                     inference_pipeline.output_dir,
+                     "segmentations",
+                     "converted_dcm.nii.gz",
+                 ),
+             )
+         return {}
+
+
+ class SpineComputeROIs(InferenceClass):
+     def __init__(self, spine_model):
+         super().__init__()
+         self.spine_model_name = spine_model
+         self.spine_model_type = Models.model_from_name(self.spine_model_name)
+
+     def __call__(self, inference_pipeline):
+         # Compute ROIs
+         inference_pipeline.spine_model_type = self.spine_model_type
+
+         (spine_hus, rois, segmentation_hus, centroids_3d) = spine_utils.compute_rois(
+             inference_pipeline.segmentation,
+             inference_pipeline.medical_volume,
+             self.spine_model_type,
+         )
+
+         inference_pipeline.spine_hus = spine_hus
+         inference_pipeline.segmentation_hus = segmentation_hus
+         inference_pipeline.rois = rois
+         inference_pipeline.centroids_3d = centroids_3d
+
+         return {}
+
+
+ class SpineMetricsSaver(InferenceClass):
+     """Save metrics to a CSV file."""
+
+     def __init__(self):
+         super().__init__()
+
+     def __call__(self, inference_pipeline):
+         """Save metrics to a CSV file."""
+         self.spine_hus = inference_pipeline.spine_hus
+         self.seg_hus = inference_pipeline.segmentation_hus
+         self.output_dir = inference_pipeline.output_dir
+         self.csv_output_dir = os.path.join(self.output_dir, "metrics")
+         if not os.path.exists(self.csv_output_dir):
+             os.makedirs(self.csv_output_dir, exist_ok=True)
+         self.save_results()
+         if hasattr(inference_pipeline, "dicom_ds"):
+             if not os.path.exists(os.path.join(self.output_dir, "dicom_metadata.csv")):
+                 io_utils.write_dicom_metadata_to_csv(
+                     inference_pipeline.dicom_ds,
+                     os.path.join(self.output_dir, "dicom_metadata.csv"),
+                 )
+
+         return {}
+
+     def save_results(self):
+         """Save results to a CSV file."""
+         df = pd.DataFrame(columns=["Level", "ROI HU", "Seg HU"])
+         for i, level in enumerate(self.spine_hus):
+             hu = self.spine_hus[level]
+             seg_hu = self.seg_hus[level]
+             row = [level, hu, seg_hu]
+             df.loc[i] = row
+         df = df.iloc[::-1]
+         df.to_csv(os.path.join(self.csv_output_dir, "spine_metrics.csv"), index=False)
+
+
+ class SpineFindDicoms(InferenceClass):
+     def __init__(self):
+         super().__init__()
+
+     def __call__(self, inference_pipeline):
+         inferior_superior_centers = spine_utils.find_spine_dicoms(
+             inference_pipeline.centroids_3d,
+         )
+
+         spine_utils.save_nifti_select_slices(
+             inference_pipeline.output_dir, inferior_superior_centers
+         )
+         inference_pipeline.dicom_file_paths = [
+             str(center) for center in inferior_superior_centers
+         ]
+         inference_pipeline.names = list(inference_pipeline.rois.keys())
+         inference_pipeline.dicom_file_names = list(inference_pipeline.rois.keys())
+         inference_pipeline.inferior_superior_centers = inferior_superior_centers
+
+         return {}
+
+
+ class SpineCoronalSagittalVisualizer(InferenceClass):
+     def __init__(self, format="png"):
+         super().__init__()
+         self.format = format
+
+     def __call__(self, inference_pipeline):
+         output_path = inference_pipeline.output_dir
+         spine_model_type = inference_pipeline.spine_model_type
+
+         img_sagittal, img_coronal = spine_utils.visualize_coronal_sagittal_spine(
+             inference_pipeline.segmentation.get_fdata(),
+             list(inference_pipeline.rois.values()),
+             inference_pipeline.medical_volume.get_fdata(),
+             list(inference_pipeline.centroids_3d.values()),
+             output_path,
+             spine_hus=inference_pipeline.spine_hus,
+             seg_hus=inference_pipeline.segmentation_hus,
+             model_type=spine_model_type,
+             pixel_spacing=inference_pipeline.pixel_spacing_list,
+             format=self.format,
+         )
+         inference_pipeline.spine_vis_sagittal = img_sagittal
+         inference_pipeline.spine_vis_coronal = img_coronal
+         inference_pipeline.spine = True
+         if not inference_pipeline.save_segmentations:
+             shutil.rmtree(os.path.join(output_path, "segmentations"))
+         return {}
+
+
+ class SpineReport(InferenceClass):
+     def __init__(self, format="png"):
+         super().__init__()
+         self.format = format
+
+     def __call__(self, inference_pipeline):
+         sagittal_image = inference_pipeline.spine_vis_sagittal
+         coronal_image = inference_pipeline.spine_vis_coronal
+         # concatenate these numpy arrays laterally
+         img = np.concatenate((coronal_image, sagittal_image), axis=1)
+         output_path = os.path.join(
+             inference_pipeline.output_dir, "images", "spine_report"
+         )
+         if self.format == "png":
+             im = Image.fromarray(img)
+             im.save(output_path + ".png")
+         elif self.format == "dcm":
+             to_dicom(img, output_path + ".dcm")
+         return {}
+
+
+ class SpineMuscleAdiposeTissueReport(InferenceClass):
+     """Spine muscle adipose tissue report class."""
+
+     def __init__(self):
+         super().__init__()
+         self.image_files = [
+             "spine_coronal.png",
+             "spine_sagittal.png",
+             "T12.png",
+             "L1.png",
+             "L2.png",
+             "L3.png",
+             "L4.png",
+             "L5.png",
+         ]
+
+     def __call__(self, inference_pipeline):
+         image_dir = Path(inference_pipeline.output_dir) / "images"
+         self.generate_panel(image_dir)
+         return {}
+
+     def generate_panel(self, image_dir: Union[str, Path]):
+         """Generate panel.
+
+         Args:
+             image_dir (Union[str, Path]): Path to the image directory.
+         """
+         image_files = [os.path.join(image_dir, path) for path in self.image_files]
+         # construct a list which includes only the images that exist
+         image_files = [path for path in image_files if os.path.exists(path)]
+
+         im_cor = Image.open(image_files[0])
+         im_sag = Image.open(image_files[1])
+         im_cor_width = int(im_cor.width / im_cor.height * 512)
+         num_muscle_fat_cols = math.ceil((len(image_files) - 2) / 2)
+         width = (8 + im_cor_width + 8) + ((512 + 8) * num_muscle_fat_cols)
+         height = 1048
+         new_im = Image.new("RGB", (width, height))
+
+         index = 2
+         for j in range(8, height, 520):
+             for i in range(8 + im_cor_width + 8, width, 520):
+                 try:
+                     im = Image.open(image_files[index])
+                     im.thumbnail((512, 512))
+                     new_im.paste(im, (i, j))
+                     index += 1
+                     im.close()
+                 except Exception:
+                     continue
+
+         im_cor.thumbnail((im_cor_width, 512))
+         new_im.paste(im_cor, (8, 8))
+         im_sag.thumbnail((im_cor_width, 512))
+         new_im.paste(im_sag, (8, 528))
+         new_im.save(os.path.join(image_dir, "spine_muscle_adipose_tissue_report.png"))
+         im_cor.close()
+         im_sag.close()
+         new_im.close()
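Note: a sketch of the cropping idiom AxialCropper uses above. nibabel's .slicer indexing returns a new image whose affine is adjusted for the crop, so the sub-volume stays spatially registered; the file path and label ids below are hypothetical stand-ins for the model's vertebra categories.

    import nibabel as nib
    import numpy as np

    seg = nib.load("output/segmentations/spine.nii.gz")  # placeholder path
    seg_data = seg.get_fdata()
    upper = int(np.where(seg_data == 31)[2].max())  # assumed label id for L1
    lower = int(np.where(seg_data == 27)[2].min())  # assumed label id for L5
    cropped = seg.slicer[:, :, lower:upper]  # affine updated automatically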
Comp2Comp-main/comp2comp/spine/spine_utils.py ADDED
@@ -0,0 +1,737 @@
1
+ """
2
+ @author: louisblankemeier
3
+ """
4
+
5
+ import logging
6
+ import math
7
+ import os
8
+ from typing import Dict, List
9
+
10
+ import cv2
11
+ import nibabel as nib
12
+ import numpy as np
13
+ from scipy.ndimage import zoom
14
+
15
+ from comp2comp.spine import spine_visualization
16
+
17
+
18
+ def find_spine_dicoms(centroids: Dict): # , path: str, levels):
19
+ """Find the dicom files corresponding to the spine T12 - L5 levels."""
20
+
21
+ vertical_positions = []
22
+ for level in centroids:
23
+ centroid = centroids[level]
24
+ vertical_positions.append(round(centroid[2]))
25
+
26
+ # dicom_files = []
27
+ # ipps = []
28
+ # for dicom_path in glob(path + "/*.dcm"):
29
+ # ipp = dcmread(dicom_path).ImagePositionPatient
30
+ # ipps.append(ipp[2])
31
+ # dicom_files.append(dicom_path)
32
+
33
+ # dicom_files = [x for _, x in sorted(zip(ipps, dicom_files))]
34
+ # dicom_files = list(np.array(dicom_files)[vertical_positions])
35
+
36
+ # return (dicom_files, levels, vertical_positions)
37
+ return vertical_positions
38
+
39
+
40
+ def save_nifti_select_slices(output_dir: str, vertical_positions):
41
+ nifti_path = os.path.join(output_dir, "segmentations", "converted_dcm.nii.gz")
42
+ nifti_in = nib.load(nifti_path)
43
+ nifti_np = nifti_in.get_fdata()
44
+ nifti_np = nifti_np[:, :, vertical_positions]
45
+ nifti_out = nib.Nifti1Image(nifti_np, nifti_in.affine, nifti_in.header)
46
+ # save the nifti
47
+ nifti_output_path = os.path.join(
48
+ output_dir, "segmentations", "converted_dcm.nii.gz"
49
+ )
50
+ nib.save(nifti_out, nifti_output_path)
51
+
52
+
53
+ # Function that takes a numpy array as input, computes the
54
+ # sagittal centroid of each label and returns a list of the
55
+ # centroids
56
+ def compute_centroids(seg: np.ndarray, spine_model_type):
57
+ """Compute the centroids of the labels.
58
+
59
+ Args:
60
+ seg (np.ndarray): Segmentation volume.
61
+ spine_model_type (str): Model type.
62
+
63
+ Returns:
64
+ List[int]: List of centroids.
65
+ """
66
+ # take values of spine_model_type.categories dictionary
67
+ # and convert to list
68
+ centroids = {}
69
+ for level in spine_model_type.categories:
70
+ label_idx = spine_model_type.categories[level]
71
+ try:
72
+ pos = compute_centroid(seg, "sagittal", label_idx)
73
+ centroids[level] = pos
74
+ except Exception:
75
+ logging.warning(f"Label {level} not found in segmentation volume.")
76
+ return centroids
77
+
78
+
79
+ # Function that takes a numpy array as input, as well as a list of centroids,
80
+ # takes a slice through the centroid on axis = 1 for each centroid
81
+ # and returns a list of the slices
82
+ def get_slices(seg: np.ndarray, centroids: Dict, spine_model_type):
83
+ """Get the slices corresponding to the centroids.
84
+
85
+ Args:
86
+ seg (np.ndarray): Segmentation volume.
87
+ centroids (List[int]): List of centroids.
88
+ spine_model_type (str): Model type.
89
+
90
+ Returns:
91
+ List[np.ndarray]: List of slices.
92
+ """
93
+ seg = seg.astype(np.uint8)
94
+ slices = {}
95
+ for level in centroids:
96
+ label_idx = spine_model_type.categories[level]
97
+ binary_seg = (seg[centroids[level], :, :] == label_idx).astype(int)
98
+ if (
99
+ np.sum(binary_seg) > 200
100
+ ): # heuristic to make sure enough of the body is showing
101
+ slices[level] = binary_seg
102
+ return slices
103
+
104
+
105
+ # Function that takes a mask and for each deletes the right most
106
+ # connected component. Returns the mask with the right most
107
+ # connected component deleted
108
+ def delete_right_most_connected_component(mask: np.ndarray):
109
+ """Delete the right most connected component corresponding to spinous processes.
110
+
111
+ Args:
112
+ mask (np.ndarray): Mask volume.
113
+
114
+ Returns:
115
+ np.ndarray: Mask volume.
116
+ """
117
+ mask = mask.astype(np.uint8)
118
+ _, labels, _, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8)
119
+ right_most_connected_component = np.argmin(centroids[1:, 1]) + 1
120
+ mask[labels == right_most_connected_component] = 0
121
+ return mask
122
+
123
+
124
+ # compute center of mass of 2d mask
125
+ def compute_center_of_mass(mask: np.ndarray):
126
+ """Compute the center of mass of a 2D mask.
127
+
128
+ Args:
129
+ mask (np.ndarray): Mask volume.
130
+
131
+ Returns:
132
+ np.ndarray: Center of mass.
133
+ """
134
+ mask = mask.astype(np.uint8)
135
+ _, _, _, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8)
136
+ center_of_mass = np.mean(centroids[1:, :], axis=0)
137
+ return center_of_mass
138
+
139
+
140
+ # Function that takes a 3d centroid and retruns a binary mask with a 3d
141
+ # roi around the centroid
142
+ def roi_from_mask(img, centroid: np.ndarray):
143
+ """Compute a 3D ROI from a 3D mask.
144
+
145
+ Args:
146
+ img (np.ndarray): Image volume.
147
+ centroid (np.ndarray): Centroid.
148
+
149
+ Returns:
150
+ np.ndarray: ROI volume.
151
+ """
152
+ roi = np.zeros(img.shape)
153
+
154
+ img_np = img.get_fdata()
155
+
156
+ pixel_spacing = img.header.get_zooms()
157
+ length_i = 5.0 / pixel_spacing[0]
158
+ length_j = 5.0 / pixel_spacing[1]
159
+ length_k = 5.0 / pixel_spacing[2]
160
+
161
+ print(
162
+ f"Computing ROI with centroid {centroid[0]:.3f}, {centroid[1]:.3f}, {centroid[2]:.3f} "
163
+ f"and pixel spacing "
164
+ f"{pixel_spacing[0]:.3f}mm, {pixel_spacing[1]:.3f}mm, {pixel_spacing[2]:.3f}mm..."
165
+ )
166
+
167
+ # cubic ROI around centroid
168
+ """
169
+ roi[
170
+ int(centroid[0] - length) : int(centroid[0] + length),
171
+ int(centroid[1] - length) : int(centroid[1] + length),
172
+ int(centroid[2] - length) : int(centroid[2] + length),
173
+ ] = 1
174
+ """
175
+ # spherical ROI around centroid
176
+ roi = np.zeros(img_np.shape)
177
+ i_lower = math.floor(centroid[0] - length_i)
178
+ j_lower = math.floor(centroid[1] - length_j)
179
+ k_lower = math.floor(centroid[2] - length_k)
180
+ i_lower_idx = 1000
181
+ j_lower_idx = 1000
182
+ k_lower_idx = 1000
183
+ i_upper_idx = 0
184
+ j_upper_idx = 0
185
+ k_upper_idx = 0
186
+ found_pixels = False
187
+ for i in range(i_lower, i_lower + 2 * math.ceil(length_i) + 1):
188
+ for j in range(j_lower, j_lower + 2 * math.ceil(length_j) + 1):
189
+ for k in range(k_lower, k_lower + 2 * math.ceil(length_k) + 1):
190
+ if (i - centroid[0]) ** 2 / length_i**2 + (
191
+ j - centroid[1]
192
+ ) ** 2 / length_j**2 + (k - centroid[2]) ** 2 / length_k**2 <= 1:
193
+ roi[i, j, k] = 1
194
+ if i < i_lower_idx:
195
+ i_lower_idx = i
196
+ if j < j_lower_idx:
197
+ j_lower_idx = j
198
+ if k < k_lower_idx:
199
+ k_lower_idx = k
200
+ if i > i_upper_idx:
201
+ i_upper_idx = i
202
+ if j > j_upper_idx:
203
+ j_upper_idx = j
204
+ if k > k_upper_idx:
205
+ k_upper_idx = k
206
+ found_pixels = True
207
+ if not found_pixels:
208
+ print("No pixels in ROI!")
209
+ raise ValueError
210
+ print(
211
+ f"Number of pixels included in i, j, and k directions: {i_upper_idx - i_lower_idx + 1}, "
212
+ f"{j_upper_idx - j_lower_idx + 1}, {k_upper_idx - k_lower_idx + 1}"
213
+ )
214
+ return roi
215
+
216
+
217
+ # Function that takes a 3d image and a 3d binary mask and returns that average
218
+ # value of the image inside the mask
219
+ def mean_img_mask(img: np.ndarray, mask: np.ndarray, index: int):
220
+ """Compute the mean of an image inside a mask.
221
+
222
+ Args:
223
+ img (np.ndarray): Image volume.
224
+ mask (np.ndarray): Mask volume.
225
+ rescale_slope (float): Rescale slope.
226
+ rescale_intercept (float): Rescale intercept.
227
+
228
+ Returns:
229
+ float: Mean value.
230
+ """
231
+ img = img.astype(np.float32)
232
+ mask = mask.astype(np.float32)
233
+ img_masked = (img * mask)[mask > 0]
234
+ # mean = (rescale_slope * np.mean(img_masked)) + rescale_intercept
235
+ # median = (rescale_slope * np.median(img_masked)) + rescale_intercept
236
+ mean = np.mean(img_masked)
237
+ return mean
238
+
239
+
240
+ def compute_rois(seg, img, spine_model_type):
241
+ """Compute the ROIs for the spine.
242
+
243
+ Args:
244
+ seg (np.ndarray): Segmentation volume.
245
+ img (np.ndarray): Image volume.
246
+ rescale_slope (float): Rescale slope.
247
+ rescale_intercept (float): Rescale intercept.
248
+ spine_model_type (Models): Model type.
249
+
250
+ Returns:
251
+ spine_hus (List[float]): List of HU values.
252
+ rois (List[np.ndarray]): List of ROIs.
253
+ centroids_3d (List[np.ndarray]): List of centroids.
254
+ """
255
+ seg_np = seg.get_fdata()
256
+ centroids = compute_centroids(seg_np, spine_model_type)
257
+ slices = get_slices(seg_np, centroids, spine_model_type)
258
+ for level in slices:
259
+ slice = slices[level]
260
+ # keep only the two largest connected components
261
+ two_largest, two = keep_two_largest_connected_components(slice)
262
+ if two:
263
+ slices[level] = delete_right_most_connected_component(two_largest)
264
+
265
+ # Compute ROIs
266
+ rois = {}
267
+ spine_hus = {}
268
+ centroids_3d = {}
269
+ segmentation_hus = {}
270
+ for i, level in enumerate(slices):
271
+ slice = slices[level]
272
+ center_of_mass = compute_center_of_mass(slice)
273
+ centroid = np.array([centroids[level], center_of_mass[1], center_of_mass[0]])
274
+ roi = roi_from_mask(img, centroid)
275
+ image_numpy = img.get_fdata()
276
+ spine_hus[level] = mean_img_mask(image_numpy, roi, i)
277
+ rois[level] = roi
278
+ mask = (seg_np == spine_model_type.categories[level]).astype(int)
279
+ segmentation_hus[level] = mean_img_mask(image_numpy, mask, i)
280
+ centroids_3d[level] = centroid
281
+ return (spine_hus, rois, segmentation_hus, centroids_3d)
282
+
283
+
284
+ def keep_two_largest_connected_components(mask: Dict):
285
+ """Keep the two largest connected components.
286
+
287
+ Args:
288
+ mask (np.ndarray): Mask volume.
289
+
290
+ Returns:
291
+ np.ndarray: Mask volume.
292
+ """
293
+ mask = mask.astype(np.uint8)
294
+ # sort connected components by size
295
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
296
+ mask, connectivity=8
297
+ )
298
+ stats = stats[1:, 4]
299
+ sorted_indices = np.argsort(stats)[::-1]
300
+ # keep only the two largest connected components
301
+ mask = np.zeros(mask.shape)
302
+ mask[labels == sorted_indices[0] + 1] = 1
303
+ two = True
304
+ try:
305
+ mask[labels == sorted_indices[1] + 1] = 1
306
+ except Exception:
307
+ two = False
308
+ return (mask, two)
309
+
310
+
311
+ def compute_centroid(seg: np.ndarray, plane: str, label: int):
312
+ """Compute the centroid of a label in a given plane.
313
+
314
+ Args:
315
+ seg (np.ndarray): Segmentation volume.
316
+ plane (str): Plane.
317
+ label (int): Label.
318
+
319
+ Returns:
320
+ int: Centroid.
321
+ """
322
+ if plane == "axial":
323
+ sum_out_axes = (0, 1)
324
+ sum_axis = 2
325
+ elif plane == "sagittal":
326
+ sum_out_axes = (1, 2)
327
+ sum_axis = 0
328
+ elif plane == "coronal":
329
+ sum_out_axes = (0, 2)
330
+ sum_axis = 1
331
+ sums = np.sum(seg == label, axis=sum_out_axes)
332
+ normalized_sums = sums / np.sum(sums)
333
+ pos = int(np.sum(np.arange(0, seg.shape[sum_axis]) * normalized_sums))
334
+ return pos
335
+
336
+
337
+ def to_one_hot(label: np.ndarray, model_type, spine_hus):
338
+ """Convert a label to one-hot encoding.
339
+
340
+ Args:
341
+ label (np.ndarray): Label volume.
342
+ model_type (Models): Model type.
343
+
344
+ Returns:
345
+ np.ndarray: One-hot encoding volume.
346
+ """
347
+ levels = list(spine_hus.keys())
348
+ levels.reverse()
349
+ one_hot_label = np.zeros((label.shape[0], label.shape[1], len(levels)))
350
+ for i, level in enumerate(levels):
351
+ label_idx = model_type.categories[level]
352
+ one_hot_label[:, :, i] = (label == label_idx).astype(int)
353
+ return one_hot_label
354
+
355
+
356
+ def visualize_coronal_sagittal_spine(
357
+ seg: np.ndarray,
358
+ rois: List[np.ndarray],
359
+ mvs: np.ndarray,
360
+ centroids_3d: np.ndarray,
361
+ output_dir: str,
362
+ spine_hus=None,
363
+ seg_hus=None,
364
+ model_type=None,
365
+ pixel_spacing=None,
366
+ format="png",
367
+ ):
368
+ """Visualize the coronal and sagittal planes of the spine.
369
+
370
+ Args:
371
+ seg (np.ndarray): Segmentation volume.
372
+ rois (List[np.ndarray]): List of ROIs.
373
+ mvs (dm.MedicalVolume): Medical volume.
374
+ centroids (List[int]): List of centroids.
375
+ label_text (List[str]): List of labels.
376
+ output_dir (str): Output directory.
377
+ spine_hus (List[float], optional): List of HU values. Defaults to None.
378
+ model_type (Models, optional): Model type. Defaults to None.
379
+ """
380
+
381
+ sagittal_vals, coronal_vals = curved_planar_reformation(mvs, centroids_3d)
382
+ zoom_factor = pixel_spacing[2] / pixel_spacing[1]
383
+
384
+ sagittal_image = mvs[sagittal_vals, :, range(len(sagittal_vals))]
385
+ sagittal_label = seg[sagittal_vals, :, range(len(sagittal_vals))]
386
+ sagittal_image = zoom(sagittal_image, (zoom_factor, 1), order=3)
387
+ sagittal_label = zoom(sagittal_label, (zoom_factor, 1), order=1).round()
388
+
389
+ one_hot_sag_label = to_one_hot(sagittal_label, model_type, spine_hus)
390
+ for roi in rois:
391
+ one_hot_roi_label = roi[sagittal_vals, :, range(len(sagittal_vals))]
392
+ one_hot_roi_label = zoom(one_hot_roi_label, (zoom_factor, 1), order=1).round()
393
+ one_hot_sag_label = np.concatenate(
394
+ (
395
+ one_hot_sag_label,
396
+ one_hot_roi_label.reshape(
397
+ (one_hot_roi_label.shape[0], one_hot_roi_label.shape[1], 1)
398
+ ),
399
+ ),
400
+ axis=2,
401
+ )
402
+
403
+ coronal_image = mvs[:, coronal_vals, range(len(coronal_vals))]
404
+ coronal_label = seg[:, coronal_vals, range(len(coronal_vals))]
405
+ coronal_image = zoom(coronal_image, (1, zoom_factor), order=3)
406
+ coronal_label = zoom(coronal_label, (1, zoom_factor), order=1).round()
407
+
408
+ # coronal_image = zoom(coronal_image, (zoom_factor, 1), order=3)
409
+ # coronal_label = zoom(coronal_label, (zoom_factor, 1), order=0).astype(int)
410
+
411
+ one_hot_cor_label = to_one_hot(coronal_label, model_type, spine_hus)
412
+ for roi in rois:
413
+ one_hot_roi_label = roi[:, coronal_vals, range(len(coronal_vals))]
414
+ one_hot_roi_label = zoom(one_hot_roi_label, (1, zoom_factor), order=1).round()
415
+ one_hot_cor_label = np.concatenate(
416
+ (
417
+ one_hot_cor_label,
418
+ one_hot_roi_label.reshape(
419
+ (one_hot_roi_label.shape[0], one_hot_roi_label.shape[1], 1)
420
+ ),
421
+ ),
422
+ axis=2,
423
+ )
424
+
425
+ # flip both axes of coronal image
426
+ sagittal_image = np.flip(sagittal_image, axis=0)
427
+ sagittal_image = np.flip(sagittal_image, axis=1)
428
+
429
+ # flip both axes of coronal label
430
+ one_hot_sag_label = np.flip(one_hot_sag_label, axis=0)
431
+ one_hot_sag_label = np.flip(one_hot_sag_label, axis=1)
432
+
433
+ coronal_image = np.transpose(coronal_image)
434
+ one_hot_cor_label = np.transpose(one_hot_cor_label, (1, 0, 2))
435
+
436
+ # flip both axes of sagittal image
437
+ coronal_image = np.flip(coronal_image, axis=0)
438
+ coronal_image = np.flip(coronal_image, axis=1)
439
+
440
+ # flip both axes of sagittal label
441
+ one_hot_cor_label = np.flip(one_hot_cor_label, axis=0)
442
+ one_hot_cor_label = np.flip(one_hot_cor_label, axis=1)
443
+
444
+ if format == "png":
445
+ sagittal_name = "spine_sagittal.png"
446
+ coronal_name = "spine_coronal.png"
447
+ elif format == "dcm":
448
+ sagittal_name = "spine_sagittal.dcm"
449
+ coronal_name = "spine_coronal.dcm"
450
+ else:
451
+ raise ValueError("Format must be either png or dcm")
452
+
453
+ img_sagittal = spine_visualization.spine_binary_segmentation_overlay(
454
+ sagittal_image,
455
+ one_hot_sag_label,
456
+ output_dir,
457
+ sagittal_name,
458
+ spine_hus=spine_hus,
459
+ seg_hus=seg_hus,
460
+ model_type=model_type,
461
+ pixel_spacing=pixel_spacing,
462
+ )
463
+ img_coronal = spine_visualization.spine_binary_segmentation_overlay(
464
+ coronal_image,
465
+ one_hot_cor_label,
466
+ output_dir,
467
+ coronal_name,
468
+ spine_hus=spine_hus,
469
+ seg_hus=seg_hus,
470
+ model_type=model_type,
471
+ pixel_spacing=pixel_spacing,
472
+ )
473
+
474
+ return img_sagittal, img_coronal
475
+
476
+
477
+ def curved_planar_reformation(mvs, centroids):
478
+ centroids = sorted(centroids, key=lambda x: x[2])
479
+ centroids = [(int(x[0]), int(x[1]), int(x[2])) for x in centroids]
480
+ sagittal_centroids = [centroids[i][0] for i in range(0, len(centroids))]
481
+ coronal_centroids = [centroids[i][1] for i in range(0, len(centroids))]
482
+ axial_centroids = [centroids[i][2] for i in range(0, len(centroids))]
483
+ sagittal_vals = [sagittal_centroids[0]] * axial_centroids[0]
484
+ coronal_vals = [coronal_centroids[0]] * axial_centroids[0]
485
+
486
+ for i in range(1, len(axial_centroids)):
487
+ num = axial_centroids[i] - axial_centroids[i - 1]
488
+ interp = list(
489
+ np.linspace(sagittal_centroids[i - 1], sagittal_centroids[i], num=num)
490
+ )
491
+ sagittal_vals.extend(interp)
492
+ interp = list(
493
+ np.linspace(coronal_centroids[i - 1], coronal_centroids[i], num=num)
494
+ )
495
+ coronal_vals.extend(interp)
496
+
497
+ sagittal_vals.extend([sagittal_centroids[-1]] * (mvs.shape[2] - len(sagittal_vals)))
498
+ coronal_vals.extend([coronal_centroids[-1]] * (mvs.shape[2] - len(coronal_vals)))
499
+ sagittal_vals = np.array(sagittal_vals)
500
+ coronal_vals = np.array(coronal_vals)
501
+ sagittal_vals = sagittal_vals.astype(int)
502
+ coronal_vals = coronal_vals.astype(int)
503
+
504
+ return (sagittal_vals, coronal_vals)
505
+
506
+
+'''
+def compare_ts_stanford_centroids(labels_path, pred_centroids):
+    """Compare the centroids of the Stanford dataset with the centroids of the TS dataset.
+
+    Args:
+        labels_path (str): Path to the Stanford dataset labels.
+    """
+    t12_diff = []
+    l1_diff = []
+    l2_diff = []
+    l3_diff = []
+    l4_diff = []
+    l5_diff = []
+    num_skipped = 0
+
+    labels = glob(labels_path + "/*")
+    for label_path in labels:
+        # modify label_path to give pred_path
+        pred_path = label_path.replace("labelsTs", "predTs_TS")
+        print(label_path.split("/")[-1])
+        label_nib = nib.load(label_path)
+        label = label_nib.get_fdata()
+        spacing = label_nib.header.get_zooms()[2]
+        pred_nib = nib.load(pred_path)
+        pred = pred_nib.get_fdata()
+        if True:
+            pred[pred == 18] = 6
+            pred[pred == 19] = 5
+            pred[pred == 20] = 4
+            pred[pred == 21] = 3
+            pred[pred == 22] = 2
+            pred[pred == 23] = 1
+
+        for label_idx in range(1, 7):
+            label_level = label == label_idx
+            indexes = np.array(range(label.shape[2]))
+            sums = np.sum(label_level, axis=(0, 1))
+            normalized_sums = sums / np.sum(sums)
+            label_centroid = np.sum(indexes * normalized_sums)
+            print(f"Centroid for label {label_idx}: {label_centroid}")
+
+            if False:
+                try:
+                    pred_centroid = pred_centroids[6 - label_idx]
+                except Exception:
+                    # Change this part
+                    print("Something wrong with pred_centroids, skipping!")
+                    num_skipped += 1
+                    break
+
+            # if revert_to_original:
+            if True:
+                pred_level = pred == label_idx
+                sums = np.sum(pred_level, axis=(0, 1))
+                indices = list(range(sums.shape[0]))
+                groupby_input = zip(indices, list(sums))
+                g = groupby(groupby_input, key=lambda x: x[1] > 0.0)
+                m = max([list(s) for v, s in g if v > 0], key=lambda x: np.sum(list(zip(*x))[1]))
+                res = list(zip(*m))
+                indexes = list(res[0])
+                sums = list(res[1])
+                normalized_sums = sums / np.sum(sums)
+                pred_centroid = np.sum(indexes * normalized_sums)
+                print(f"Centroid for prediction {label_idx}: {pred_centroid}")
+
+            diff = np.absolute(pred_centroid - label_centroid) * spacing
+
+            if label_idx == 1:
+                t12_diff.append(diff)
+            elif label_idx == 2:
+                l1_diff.append(diff)
+            elif label_idx == 3:
+                l2_diff.append(diff)
+            elif label_idx == 4:
+                l3_diff.append(diff)
+            elif label_idx == 5:
+                l4_diff.append(diff)
+            elif label_idx == 6:
+                l5_diff.append(diff)
+
+    print(f"Skipped {num_skipped}")
+    print("The final mean differences in mm:")
+    print(
+        np.mean(t12_diff),
+        np.mean(l1_diff),
+        np.mean(l2_diff),
+        np.mean(l3_diff),
+        np.mean(l4_diff),
+        np.mean(l5_diff),
+    )
+    print("The final median differences in mm:")
+    print(
+        np.median(t12_diff),
+        np.median(l1_diff),
+        np.median(l2_diff),
+        np.median(l3_diff),
+        np.median(l4_diff),
+        np.median(l5_diff),
+    )
+
+
+def compare_ts_stanford_roi_hus(image_path):
+    """Compare the HU values of the Stanford dataset with the HU values of the TS dataset.
+
+    image_path (str): Path to the Stanford dataset images.
+    """
+    img_paths = glob(image_path + "/*")
+    differences = np.zeros((40, 6))
+    ground_truth = np.zeros((40, 6))
+    for i, img_path in enumerate(img_paths):
+        print(f"Image number {i + 1}")
+        image_path_no_0000 = re.sub(r"_0000", "", img_path)
+        ts_seg_path = image_path_no_0000.replace("imagesTs", "predTs_TS")
+        stanford_seg_path = image_path_no_0000.replace("imagesTs", "labelsTs")
+        img = nib.load(img_path).get_fdata()
+        img = np.swapaxes(img, 0, 1)
+        ts_seg = nib.load(ts_seg_path).get_fdata()
+        ts_seg = np.swapaxes(ts_seg, 0, 1)
+        stanford_seg = nib.load(stanford_seg_path).get_fdata()
+        stanford_seg = np.swapaxes(stanford_seg, 0, 1)
+        ts_model_type = Models.model_from_name("ts_spine")
+        (spine_hus_ts, rois, centroids_3d) = compute_rois(ts_seg, img, 1, 0, ts_model_type)
+        stanford_model_type = Models.model_from_name("stanford_spine_v0.0.1")
+        (spine_hus_stanford, rois, centroids_3d) = compute_rois(
+            stanford_seg, img, 1, 0, stanford_model_type
+        )
+        difference_vals = np.abs(np.array(spine_hus_ts) - np.array(spine_hus_stanford))
+        print(f"Differences {difference_vals}\n")
+        differences[i, :] = difference_vals
+        ground_truth[i, :] = spine_hus_stanford
+        print("\n")
+    # compute average percent change from ground truth
+    percent_change = np.divide(differences, ground_truth) * 100
+    average_percent_change = np.mean(percent_change, axis=0)
+    median_percent_change = np.median(percent_change, axis=0)
+    # print average percent change
+    print("Average percent change from ground truth:")
+    print(average_percent_change)
+    print("Median percent change from ground truth:")
+    print(median_percent_change)
+    # print average difference
+    average_difference = np.mean(differences, axis=0)
+    median_difference = np.median(differences, axis=0)
+    print("Average difference from ground truth:")
+    print(average_difference)
+    print("Median difference from ground truth:")
+    print(median_difference)
+
+
+def process_post_hoc(pred_path):
+    """Apply post-hoc heuristics for improving Stanford spine model vertical centroid predictions.
+
+    Args:
+        pred_path (str): Path to the prediction.
+    """
+    pred_nib = nib.load(pred_path)
+    pred = pred_nib.get_fdata()
+
+    pred_bodies = np.logical_and(pred >= 1, pred <= 6)
+    pred_bodies = pred_bodies.astype(np.int64)
+
+    labels_out, N = cc3d.connected_components(pred_bodies, return_N=True, connectivity=6)
+
+    stats = cc3d.statistics(labels_out)
+    print(stats)
+
+    labels_out_list = []
+    voxel_counts_list = list(stats["voxel_counts"])
+    for idx_lab in range(1, N + 2):
+        labels_out_list.append(labels_out == idx_lab)
+
+    centroids_list = list(stats["centroids"][:, 2])
+
+    labels = []
+    centroids = []
+    voxels = []
+
+    for idx, count in enumerate(voxel_counts_list):
+        if count > 10000:
+            labels.append(labels_out_list[idx])
+            centroids.append(centroids_list[idx])
+            voxels.append(count)
+
+    top_comps = [
+        (counts0, labels0, centroids0)
+        for counts0, labels0, centroids0 in sorted(zip(voxels, labels, centroids), reverse=True)
+    ]
+    top_comps = top_comps[1:7]
+
+    # ====== Check whether the connected components are fusing vertebral bodies ======
+    revert_to_original = False
+
+    volumes = list(zip(*top_comps))[0]
+    if volumes[0] > 1.5 * volumes[1]:
+        revert_to_original = True
+        print("Reverting to original...")
+
+    labels = list(zip(*top_comps))[1]
+    centroids = list(zip(*top_comps))[2]
+
+    top_comps = zip(centroids, labels)
+    pred_centroids = [x for x, _ in sorted(top_comps)]
+
+    for label_idx in range(1, 7):
+        if not revert_to_original:
+            try:
+                pred_centroid = pred_centroids[6 - label_idx]
+            except Exception:
+                # Change this part
+                print(
+                    "Post processing failure, probably < 6 predicted bodies. Reverting to original labels."
+                )
+                revert_to_original = True
+
+        if revert_to_original:
+            pred_level = pred == label_idx
+            sums = np.sum(pred_level, axis=(0, 1))
+            indices = list(range(sums.shape[0]))
+            groupby_input = zip(indices, list(sums))
+            # sys.exit()
+            g = groupby(groupby_input, key=lambda x: x[1] > 0.0)
+            m = max([list(s) for v, s in g if v > 0], key=lambda x: np.sum(list(zip(*x))[1]))
+            # sys.exit()
+            # m = max([list(s) for v, s in g], key=lambda np.sum)
+            res = list(zip(*m))
+            indexes = list(res[0])
+            sums = list(res[1])
+            normalized_sums = sums / np.sum(sums)
+            pred_centroid = np.sum(indexes * normalized_sums)
+            print(f"Centroid for prediction {label_idx}: {pred_centroid}")
+'''
Comp2Comp-main/comp2comp/spine/spine_visualization.py ADDED
@@ -0,0 +1,198 @@
+"""
+@author: louisblankemeier
+"""
+
+import os
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+
+from comp2comp.visualization.detectron_visualizer import Visualizer
+
+
+def spine_binary_segmentation_overlay(
+    img_in: Union[str, Path],
+    mask: Union[str, Path],
+    base_path: Union[str, Path],
+    file_name: str,
+    figure_text_key=None,
+    spine_hus=None,
+    seg_hus=None,
+    spine=True,
+    model_type=None,
+    pixel_spacing=None,
+):
+    """Save binary segmentation overlay.
+
+    Args:
+        img_in (Union[str, Path]): Path to the input image.
+        mask (Union[str, Path]): Path to the mask.
+        base_path (Union[str, Path]): Path to the output directory.
+        file_name (str): Output file name.
+        figure_text_key (dict, optional): Figure text key. Defaults to None.
+        spine_hus (dict, optional): ROI HU values keyed by level. Defaults to None.
+        seg_hus (dict, optional): Segmentation HU values keyed by level. Defaults to None.
+        spine (bool, optional): Spine flag. Defaults to True.
+        model_type (Models, optional): Model type. Defaults to None.
+        pixel_spacing (sequence, optional): Pixel spacing. Defaults to None.
+    """
+    # One RGB triplet per row.
+    _COLORS = (
+        np.array(
+            [
+                1.000, 0.000, 0.000,
+                0.000, 1.000, 0.000,
+                1.000, 1.000, 0.000,
+                1.000, 0.500, 0.000,
+                0.000, 1.000, 1.000,
+                1.000, 0.000, 1.000,
+            ]
+        )
+        .astype(np.float32)
+        .reshape(-1, 3)
+    )
+
+    label_map = {"L5": 0, "L4": 1, "L3": 2, "L2": 3, "L1": 4, "T12": 5}
+
+    _ROI_COLOR = np.array([1.000, 0.340, 0.200])
+
+    _SPINE_TEXT_OFFSET_FROM_TOP = 10.0
+    _SPINE_TEXT_OFFSET_FROM_RIGHT = 40.0
+    _SPINE_TEXT_VERTICAL_SPACING = 14.0
+
+    img_in = np.clip(img_in, -300, 1800)
+    img_in = normalize_img(img_in) * 255.0
+    images_base_path = Path(base_path) / "images"
+    images_base_path.mkdir(exist_ok=True)
+
+    img_in = img_in.reshape((img_in.shape[0], img_in.shape[1], 1))
+    img_rgb = np.tile(img_in, (1, 1, 3))
+
+    vis = Visualizer(img_rgb)
+
+    levels = list(spine_hus.keys())
+    levels.reverse()
+    num_levels = len(levels)
+
+    # draw seg masks
+    for i, level in enumerate(levels):
+        color = _COLORS[label_map[level]]
+        edge_color = None
+        alpha_val = 0.2
+        vis.draw_binary_mask(
+            mask[:, :, i].astype(int),
+            color=color,
+            edge_color=edge_color,
+            alpha=alpha_val,
+            area_threshold=0,
+        )
+
+    # draw rois
+    for i, _ in enumerate(levels):
+        color = _ROI_COLOR
+        edge_color = color
+        vis.draw_binary_mask(
+            mask[:, :, num_levels + i].astype(int),
+            color=color,
+            edge_color=edge_color,
+            alpha=alpha_val,
+            area_threshold=0,
+        )
+
+    vis.draw_text(
+        text="ROI",
+        position=(
+            mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT - 35,
+            _SPINE_TEXT_OFFSET_FROM_TOP,
+        ),
+        color=[1, 1, 1],
+        font_size=9,
+        horizontal_alignment="center",
+    )
+
+    vis.draw_text(
+        text="Seg",
+        position=(
+            mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT,
+            _SPINE_TEXT_OFFSET_FROM_TOP,
+        ),
+        color=[1, 1, 1],
+        font_size=9,
+        horizontal_alignment="center",
+    )
+
+    # draw text and lines
+    for i, level in enumerate(levels):
+        vis.draw_text(
+            text=f"{level}:",
+            position=(
+                mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT - 80,
+                _SPINE_TEXT_VERTICAL_SPACING * (i + 1) + _SPINE_TEXT_OFFSET_FROM_TOP,
+            ),
+            color=_COLORS[label_map[level]],
+            font_size=9,
+            horizontal_alignment="left",
+        )
+        vis.draw_text(
+            text=f"{round(float(spine_hus[level]))}",
+            position=(
+                mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT - 35,
+                _SPINE_TEXT_VERTICAL_SPACING * (i + 1) + _SPINE_TEXT_OFFSET_FROM_TOP,
+            ),
+            color=_COLORS[label_map[level]],
+            font_size=9,
+            horizontal_alignment="center",
+        )
+        vis.draw_text(
+            text=f"{round(float(seg_hus[level]))}",
+            position=(
+                mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT,
+                _SPINE_TEXT_VERTICAL_SPACING * (i + 1) + _SPINE_TEXT_OFFSET_FROM_TOP,
+            ),
+            color=_COLORS[label_map[level]],
+            font_size=9,
+            horizontal_alignment="center",
+        )
+
+    """
+    vis.draw_line(
+        x_data=(0, mask.shape[1] - 1),
+        y_data=(
+            int(
+                inferior_superior_centers[num_levels - i - 1]
+                * (pixel_spacing[2] / pixel_spacing[1])
+            ),
+            int(
+                inferior_superior_centers[num_levels - i - 1]
+                * (pixel_spacing[2] / pixel_spacing[1])
+            ),
+        ),
+        color=_COLORS[label_map[level]],
+        linestyle="dashed",
+        linewidth=0.25,
+    )
+    """
+
+    vis_obj = vis.get_output()
+    img = vis_obj.save(os.path.join(images_base_path, file_name))
+    return img
+
+
+def normalize_img(img: np.ndarray) -> np.ndarray:
+    """Normalize the image.
+
+    Args:
+        img (np.ndarray): Input image.
+
+    Returns:
+        np.ndarray: Normalized image.
+    """
+    return (img - img.min()) / (img.max() - img.min())
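# A minimal sketch (not part of the commit) of the intensity pipeline used
# above: clip CT intensities to a HU window, min-max normalize, and tile to
# three channels so the Visualizer can draw on an RGB canvas. The input
# array here is synthetic.
import numpy as np

ct_slice = np.random.randint(-1000, 2000, size=(256, 256)).astype(np.float32)
windowed = np.clip(ct_slice, -300, 1800)
normalized = (windowed - windowed.min()) / (windowed.max() - windowed.min())
img_rgb = np.tile((normalized * 255.0)[..., None], (1, 1, 3))  # (H, W, 3)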
Comp2Comp-main/comp2comp/utils/__init__.py ADDED
File without changes
Comp2Comp-main/comp2comp/utils/colormap.py ADDED
@@ -0,0 +1,156 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+"""
+An awesome colormap for really neat visualizations.
+Copied from Detectron, and removed gray colors.
+"""
+
+import random
+
+import numpy as np
+
+__all__ = ["colormap", "random_color", "random_colors"]
+
+# fmt: off
+# RGB:
+_COLORS = np.array(
+    [
+        0.000, 0.447, 0.741,
+        0.850, 0.325, 0.098,
+        0.929, 0.694, 0.125,
+        0.494, 0.184, 0.556,
+        0.466, 0.674, 0.188,
+        0.301, 0.745, 0.933,
+        0.635, 0.078, 0.184,
+        0.300, 0.300, 0.300,
+        0.600, 0.600, 0.600,
+        1.000, 0.000, 0.000,
+        1.000, 0.500, 0.000,
+        0.749, 0.749, 0.000,
+        0.000, 1.000, 0.000,
+        0.000, 0.000, 1.000,
+        0.667, 0.000, 1.000,
+        0.333, 0.333, 0.000,
+        0.333, 0.667, 0.000,
+        0.333, 1.000, 0.000,
+        0.667, 0.333, 0.000,
+        0.667, 0.667, 0.000,
+        0.667, 1.000, 0.000,
+        1.000, 0.333, 0.000,
+        1.000, 0.667, 0.000,
+        1.000, 1.000, 0.000,
+        0.000, 0.333, 0.500,
+        0.000, 0.667, 0.500,
+        0.000, 1.000, 0.500,
+        0.333, 0.000, 0.500,
+        0.333, 0.333, 0.500,
+        0.333, 0.667, 0.500,
+        0.333, 1.000, 0.500,
+        0.667, 0.000, 0.500,
+        0.667, 0.333, 0.500,
+        0.667, 0.667, 0.500,
+        0.667, 1.000, 0.500,
+        1.000, 0.000, 0.500,
+        1.000, 0.333, 0.500,
+        1.000, 0.667, 0.500,
+        1.000, 1.000, 0.500,
+        0.000, 0.333, 1.000,
+        0.000, 0.667, 1.000,
+        0.000, 1.000, 1.000,
+        0.333, 0.000, 1.000,
+        0.333, 0.333, 1.000,
+        0.333, 0.667, 1.000,
+        0.333, 1.000, 1.000,
+        0.667, 0.000, 1.000,
+        0.667, 0.333, 1.000,
+        0.667, 0.667, 1.000,
+        0.667, 1.000, 1.000,
+        1.000, 0.000, 1.000,
+        1.000, 0.333, 1.000,
+        1.000, 0.667, 1.000,
+        0.333, 0.000, 0.000,
+        0.500, 0.000, 0.000,
+        0.667, 0.000, 0.000,
+        0.833, 0.000, 0.000,
+        1.000, 0.000, 0.000,
+        0.000, 0.167, 0.000,
+        0.000, 0.333, 0.000,
+        0.000, 0.500, 0.000,
+        0.000, 0.667, 0.000,
+        0.000, 0.833, 0.000,
+        0.000, 1.000, 0.000,
+        0.000, 0.000, 0.167,
+        0.000, 0.000, 0.333,
+        0.000, 0.000, 0.500,
+        0.000, 0.000, 0.667,
+        0.000, 0.000, 0.833,
+        0.000, 0.000, 1.000,
+        0.000, 0.000, 0.000,
+        0.143, 0.143, 0.143,
+        0.857, 0.857, 0.857,
+        1.000, 1.000, 1.000
+    ]
+).astype(np.float32).reshape(-1, 3)
+# fmt: on
+
+
+def colormap(rgb=False, maximum=255):
+    """
+    Args:
+        rgb (bool): whether to return RGB colors or BGR colors.
+        maximum (int): either 255 or 1
+    Returns:
+        ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1]
+    """
+    assert maximum in [255, 1], maximum
+    c = _COLORS * maximum
+    if not rgb:
+        c = c[:, ::-1]
+    return c
+
+
+def random_color(rgb=False, maximum=255):
+    """
+    Args:
+        rgb (bool): whether to return RGB colors or BGR colors.
+        maximum (int): either 255 or 1
+    Returns:
+        ndarray: a vector of 3 numbers
+    """
+    idx = np.random.randint(0, len(_COLORS))
+    ret = _COLORS[idx] * maximum
+    if not rgb:
+        ret = ret[::-1]
+    return ret
+
+
+def random_colors(N, rgb=False, maximum=255):
+    """
+    Args:
+        N (int): number of unique colors needed
+        rgb (bool): whether to return RGB colors or BGR colors.
+        maximum (int): either 255 or 1
+    Returns:
+        ndarray: a list of random_color
+    """
+    indices = random.sample(range(len(_COLORS)), N)
+    ret = [_COLORS[i] * maximum for i in indices]
+    if not rgb:
+        ret = [x[::-1] for x in ret]
+    return ret
+
+
+if __name__ == "__main__":
+    import cv2
+
+    size = 100
+    H, W = 10, 10
+    canvas = np.random.rand(H * size, W * size, 3).astype("float32")
+    for h in range(H):
+        for w in range(W):
+            idx = h * W + w
+            if idx >= len(_COLORS):
+                break
+            canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx]
+    cv2.imshow("a", canvas)
+    cv2.waitKey(0)
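# Illustrative usage (a sketch, not part of the commit): sample distinct
# RGB colors in [0, 1] for five classes, or grab the full palette.
colors = random_colors(5, rgb=True, maximum=1)  # list of 5 RGB triplets
palette = colormap(rgb=True, maximum=255)       # full Nx3 float32 palette
print(len(colors), palette.shape)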
Comp2Comp-main/comp2comp/utils/dl_utils.py ADDED
@@ -0,0 +1,80 @@
+import subprocess
+
+from keras import Model
+
+
+def get_available_gpus(num_gpus: int = None):
+    """Get gpu ids for gpus that are >95% free.
+
+    Tensorflow does not support checking free memory on gpus.
+    This is a crude method that relies on `nvidia-smi` to
+    determine which gpus are occupied and which are free.
+
+    Args:
+        num_gpus: Number of requested gpus. If not specified,
+            ids of all available gpu(s) are returned.
+
+    Returns:
+        List[int]: List of gpu ids that are free. Length
+            will equal `num_gpus`, if specified.
+    """
+    # Built-in tensorflow gpu id.
+    assert isinstance(num_gpus, (type(None), int))
+    if num_gpus == 0:
+        return [-1]
+
+    num_requested_gpus = num_gpus
+    try:
+        num_gpus = (
+            len(
+                subprocess.check_output("nvidia-smi --list-gpus", shell=True)
+                .decode()
+                .split("\n")
+            )
+            - 1
+        )
+
+        out_str = subprocess.check_output("nvidia-smi | grep MiB", shell=True).decode()
+    except subprocess.CalledProcessError:
+        return None
+    mem_str = [x for x in out_str.split() if "MiB" in x]
+    # First 2 * num_gpu elements correspond to memory for gpus
+    # Order: (occupied-0, total-0, occupied-1, total-1, ...)
+    mems = [float(x[:-3]) for x in mem_str]
+    gpu_percent_occupied_mem = [
+        mems[2 * gpu_id] / mems[2 * gpu_id + 1] for gpu_id in range(num_gpus)
+    ]
+
+    available_gpus = [
+        gpu_id for gpu_id, mem in enumerate(gpu_percent_occupied_mem) if mem < 0.05
+    ]
+    if num_requested_gpus and num_requested_gpus > len(available_gpus):
+        raise ValueError(
+            "Requested {} gpus, only {} are free".format(
+                num_requested_gpus, len(available_gpus)
+            )
+        )
+
+    return available_gpus[:num_requested_gpus] if num_requested_gpus else available_gpus
+
+
+class ModelMGPU(Model):
+    """Wrapper for distributing model across multiple gpus"""
+
+    def __init__(self, ser_model, gpus):
+        # `multi_gpu_model` was removed from recent Keras/TensorFlow releases;
+        # import it lazily so this module stays importable without it.
+        from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model
+
+        pmodel = multi_gpu_model(ser_model, gpus)
+        self.__dict__.update(pmodel.__dict__)
+        self._smodel = ser_model
+
+    def __getattribute__(self, attrname):
+        """Override load and save methods to be used from the serial-model. The
+        serial-model holds references to the weights in the multi-gpu model.
+        """
+        if "load" in attrname or "save" in attrname:
+            return getattr(self._smodel, attrname)
+
+        return super(ModelMGPU, self).__getattribute__(attrname)
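# Illustrative usage (a sketch, not part of the commit): pick one free GPU
# and pin the process to it before TensorFlow initializes CUDA.
import os

gpu_ids = get_available_gpus(num_gpus=1)
if gpu_ids:  # None means nvidia-smi was unavailable
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_ids[0])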
Comp2Comp-main/comp2comp/utils/env.py ADDED
@@ -0,0 +1,84 @@
+import importlib
+import importlib.util
+import os
+import sys
+
+__all__ = []
+
+
+# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path  # noqa
+def _import_file(module_name, file_path, make_importable=False):
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    if make_importable:
+        sys.modules[module_name] = module
+    return module
+
+
+def _configure_libraries():
+    """
+    Configurations for some libraries.
+    """
+    # An environment option to disable `import cv2` globally,
+    # in case it leads to negative performance impact
+    disable_cv2 = int(os.environ.get("MEDSEGPY_DISABLE_CV2", False))
+    if disable_cv2:
+        sys.modules["cv2"] = None
+    else:
+        # Disable opencl in opencv since its interaction with cuda often
+        # has negative effects
+        # This envvar is supported after OpenCV 3.4.0
+        os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
+        try:
+            import cv2
+
+            if int(cv2.__version__.split(".")[0]) >= 3:
+                cv2.ocl.setUseOpenCL(False)
+        except ImportError:
+            pass
+
+
+_ENV_SETUP_DONE = False
+
+
+def setup_environment():
+    """Perform environment setup work. The default setup is a no-op, but this
+    function allows the user to specify a Python source file or a module in
+    the $MEDSEGPY_ENV_MODULE environment variable, that performs
+    custom setup work that may be necessary to their computing environment.
+    """
+    global _ENV_SETUP_DONE
+    if _ENV_SETUP_DONE:
+        return
+    _ENV_SETUP_DONE = True
+
+    _configure_libraries()
+
+    custom_module_path = os.environ.get("MEDSEGPY_ENV_MODULE")
+
+    if custom_module_path:
+        setup_custom_environment(custom_module_path)
+    else:
+        # The default setup is a no-op
+        pass
+
+
+def setup_custom_environment(custom_module):
+    """
+    Load custom environment setup by importing a Python source file or a
+    module, and run the setup function.
+    """
+    if custom_module.endswith(".py"):
+        module = _import_file("medsegpy.utils.env.custom_module", custom_module)
+    else:
+        module = importlib.import_module(custom_module)
+    assert hasattr(module, "setup_environment") and callable(
+        module.setup_environment
+    ), (
+        "Custom environment module defined in {} does not have the "
+        "required callable attribute 'setup_environment'."
+    ).format(custom_module)
+    module.setup_environment()
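# A sketch (not part of the commit) of a custom environment module that
# setup_environment() above would load when MEDSEGPY_ENV_MODULE points at it,
# e.g. MEDSEGPY_ENV_MODULE=/path/to/my_env.py. The body is hypothetical; the
# only requirement is a callable named setup_environment.
def setup_environment():
    import os

    # Example: quiet TensorFlow logging on shared compute nodes.
    os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")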
Comp2Comp-main/comp2comp/utils/logger.py ADDED
@@ -0,0 +1,209 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import functools
+import logging
+import os
+import sys
+import time
+from collections import Counter
+
+from termcolor import colored
+
+logging.captureWarnings(True)
+
+
+class _ColorfulFormatter(logging.Formatter):
+    def __init__(self, *args, **kwargs):
+        self._root_name = kwargs.pop("root_name") + "."
+        self._abbrev_name = kwargs.pop("abbrev_name", "")
+        if len(self._abbrev_name):
+            self._abbrev_name = self._abbrev_name + "."
+        super(_ColorfulFormatter, self).__init__(*args, **kwargs)
+
+    def formatMessage(self, record):
+        record.name = record.name.replace(self._root_name, self._abbrev_name)
+        log = super(_ColorfulFormatter, self).formatMessage(record)
+        if record.levelno == logging.WARNING:
+            prefix = colored("WARNING", "red", attrs=["blink"])
+        elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+            prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+        else:
+            return log
+        return prefix + " " + log
+
+
+@functools.lru_cache()  # so that calling setup_logger multiple times won't add many handlers  # noqa
+def setup_logger(
+    output=None,
+    distributed_rank=0,
+    *,
+    color=True,
+    name="Comp2Comp",
+    abbrev_name=None,
+):
+    """
+    Initialize the Comp2Comp logger and set its verbosity level to "INFO".
+
+    Args:
+        output (str): a file name or a directory to save log. If None, will not
+            save log file. If ends with ".txt" or ".log", assumed to be a file
+            name. Otherwise, logs will be saved to `output/log.txt`.
+        name (str): the root module name of this logger
+        abbrev_name (str): an abbreviation of the module, to avoid long names in
+            logs. Set to "" to not log the root module in logs.
+            By default, the root module name is used unchanged.
+
+    Returns:
+        logging.Logger: a logger
+    """
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.DEBUG)
+    logger.propagate = False
+    if abbrev_name is None:
+        abbrev_name = name
+
+    plain_formatter = logging.Formatter(
+        "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
+        datefmt="%m/%d %H:%M:%S",
+    )
+    # stdout logging: master only
+    if distributed_rank == 0:
+        ch = logging.StreamHandler(stream=sys.stdout)
+        ch.setLevel(logging.DEBUG)
+        if color:
+            formatter = _ColorfulFormatter(
+                colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
+                datefmt="%m/%d %H:%M:%S",
+                root_name=name,
+                abbrev_name=str(abbrev_name),
+            )
+        else:
+            formatter = plain_formatter
+        ch.setFormatter(formatter)
+        logger.addHandler(ch)
+
+    # file logging: all workers
+    if output is not None:
+        if output.endswith(".txt") or output.endswith(".log"):
+            filename = output
+        else:
+            filename = os.path.join(output, "log.txt")
+        if distributed_rank > 0:
+            filename = filename + ".rank{}".format(distributed_rank)
+        os.makedirs(os.path.dirname(filename), exist_ok=True)
+
+        fh = logging.StreamHandler(_cached_log_stream(filename))
+        fh.setLevel(logging.DEBUG)
+        fh.setFormatter(plain_formatter)
+        logger.addHandler(fh)
+
+    return logger
+
+
+# cache the opened file object, so that different calls to `setup_logger`
+# with the same file name can safely write to the same file.
+@functools.lru_cache(maxsize=None)
+def _cached_log_stream(filename):
+    return open(filename, "a")
+
+
+"""
+Below are some other convenient logging methods.
+They are mainly adopted from
+https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
+"""
+
+
+def _find_caller():
+    """
+    Returns:
+        str: module name of the caller
+        tuple: a hashable key to be used to identify different callers
+    """
+    frame = sys._getframe(2)
+    while frame:
+        code = frame.f_code
+        if os.path.join("utils", "logger.") not in code.co_filename:
+            mod_name = frame.f_globals["__name__"]
+            if mod_name == "__main__":
+                mod_name = "detectron2"
+            return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
+        frame = frame.f_back
+
+
+_LOG_COUNTER = Counter()
+_LOG_TIMER = {}
+
+
+def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
+    """
+    Log only for the first n times.
+
+    Args:
+        lvl (int): the logging level
+        msg (str):
+        n (int):
+        name (str): name of the logger to use. Will use the caller's module by
+            default.
+        key (str or tuple[str]): the string(s) can be one of "caller" or
+            "message", which defines how to identify duplicated logs.
+            For example, if called with `n=1, key="caller"`, this function
+            will only log the first call from the same caller, regardless of
+            the message content.
+            If called with `n=1, key="message"`, this function will log the
+            same content only once, even if they are called from different
+            places.
+            If called with `n=1, key=("caller", "message")`, this function
+            will not log only if the same caller has logged the same message
+            before.
+    """
+    if isinstance(key, str):
+        key = (key,)
+    assert len(key) > 0
+
+    caller_module, caller_key = _find_caller()
+    hash_key = ()
+    if "caller" in key:
+        hash_key = hash_key + caller_key
+    if "message" in key:
+        hash_key = hash_key + (msg,)
+
+    _LOG_COUNTER[hash_key] += 1
+    if _LOG_COUNTER[hash_key] <= n:
+        logging.getLogger(name or caller_module).log(lvl, msg)
+
+
+def log_every_n(lvl, msg, n=1, *, name=None):
+    """
+    Log once per n times.
+
+    Args:
+        lvl (int): the logging level
+        msg (str):
+        n (int):
+        name (str): name of the logger to use. Will use the caller's module by
+            default.
+    """
+    caller_module, key = _find_caller()
+    _LOG_COUNTER[key] += 1
+    if n == 1 or _LOG_COUNTER[key] % n == 1:
+        logging.getLogger(name or caller_module).log(lvl, msg)
+
+
+def log_every_n_seconds(lvl, msg, n=1, *, name=None):
+    """
+    Log no more than once per n seconds.
+
+    Args:
+        lvl (int): the logging level
+        msg (str):
+        n (int):
+        name (str): name of the logger to use. Will use the caller's module by
+            default.
+    """
+    caller_module, key = _find_caller()
+    last_logged = _LOG_TIMER.get(key, None)
+    current_time = time.time()
+    if last_logged is None or current_time - last_logged >= n:
+        logging.getLogger(name or caller_module).log(lvl, msg)
+        _LOG_TIMER[key] = current_time
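# Illustrative usage (a sketch, not part of the commit): set up the shared
# logger once, then rate-limit a noisy message inside a loop.
logger = setup_logger(output="./outputs", name="Comp2Comp")
logger.info("pipeline started")

for _ in range(100):
    log_every_n_seconds(logging.INFO, "still running", n=5, name="Comp2Comp")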
Comp2Comp-main/comp2comp/utils/orientation.py ADDED
@@ -0,0 +1,30 @@
+import nibabel as nib
+
+from comp2comp.inference_class_base import InferenceClass
+
+
+class ToCanonical(InferenceClass):
+    """Convert the segmentation and medical volume to canonical (RAS+) orientation."""
+
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, inference_pipeline):
+        """
+        First dim goes from L to R.
+        Second dim goes from P to A.
+        Third dim goes from I to S.
+        """
+        canonical_segmentation = nib.as_closest_canonical(
+            inference_pipeline.segmentation
+        )
+        canonical_medical_volume = nib.as_closest_canonical(
+            inference_pipeline.medical_volume
+        )
+
+        inference_pipeline.segmentation = canonical_segmentation
+        inference_pipeline.medical_volume = canonical_medical_volume
+        inference_pipeline.pixel_spacing_list = (
+            canonical_medical_volume.header.get_zooms()
+        )
+        return {}
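# A minimal sketch (not part of the commit) of what ToCanonical relies on:
# nibabel reorients an image to the closest RAS+ orientation (L->R, P->A,
# I->S). The toy affine below starts out LPS-oriented.
import nibabel as nib
import numpy as np

img = nib.Nifti1Image(np.zeros((4, 4, 4)), affine=np.diag([-1, -1, 1, 1]))
canonical = nib.as_closest_canonical(img)
print(nib.aff2axcodes(canonical.affine))  # ('R', 'A', 'S')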
Comp2Comp-main/comp2comp/utils/process.py ADDED
@@ -0,0 +1,120 @@
+"""
+@author: louisblankemeier
+"""
+
+import os
+import shutil
+import sys
+import time
+import traceback
+from datetime import datetime
+from pathlib import Path
+
+from comp2comp.io import io_utils
+
+
+def process_2d(args, pipeline_builder):
+    output_dir = Path(
+        os.path.join(
+            os.path.dirname(os.path.abspath(__file__)),
+            "../../outputs",
+            datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
+        )
+    )
+    if not os.path.exists(output_dir):
+        output_dir.mkdir(parents=True)
+
+    model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../models")
+    if not os.path.exists(model_dir):
+        os.mkdir(model_dir)
+
+    pipeline = pipeline_builder(args)
+
+    pipeline(output_dir=output_dir, model_dir=model_dir)
+
+
+def process_3d(args, pipeline_builder):
+    model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../models")
+    if not os.path.exists(model_dir):
+        os.mkdir(model_dir)
+
+    if args.output_path is not None:
+        output_path = Path(args.output_path)
+    else:
+        output_path = os.path.join(
+            os.path.dirname(os.path.abspath(__file__)), "../../outputs"
+        )
+
+    if not args.overwrite_outputs:
+        date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+        output_path = os.path.join(output_path, date_time)
+
+    for path, num in io_utils.get_dicom_or_nifti_paths_and_num(args.input_path):
+        try:
+            st = time.time()
+
+            if path.endswith(".nii") or path.endswith(".nii.gz"):
+                print("Processing: ", path)
+
+            else:
+                print("Processing: ", path, " with ", num, " slices")
+                min_slices = 30
+                if num < min_slices:
+                    print(f"Number of slices is less than {min_slices}, skipping\n")
+                    continue
+
+            print("")
+
+            try:
+                sys.stdout.flush()
+            except Exception:
+                pass
+
+            if path.endswith(".nii") or path.endswith(".nii.gz"):
+                folder_name = Path(os.path.basename(os.path.normpath(path)))
+                # remove .nii or .nii.gz
+                folder_name = os.path.normpath(
+                    Path(str(folder_name).replace(".gz", "").replace(".nii", ""))
+                )
+                output_dir = Path(
+                    os.path.join(
+                        output_path,
+                        folder_name,
+                    )
+                )
+
+            else:
+                output_dir = Path(
+                    os.path.join(
+                        output_path,
+                        Path(os.path.basename(os.path.normpath(args.input_path))),
+                        os.path.relpath(
+                            os.path.normpath(path), os.path.normpath(args.input_path)
+                        ),
+                    )
+                )
+
+            if not os.path.exists(output_dir):
+                output_dir.mkdir(parents=True)
+
+            pipeline = pipeline_builder(path, args)
+
+            pipeline(output_dir=output_dir, model_dir=model_dir)
+
+            if not args.save_segmentations:
+                # remove the segmentations folder
+                segmentations_dir = os.path.join(output_dir, "segmentations")
+                if os.path.exists(segmentations_dir):
+                    shutil.rmtree(segmentations_dir)
+
+            print(f"Finished processing {path} in {time.time() - st:.1f} seconds\n")
+
+        except Exception:
+            print(f"ERROR PROCESSING {path}\n")
+            traceback.print_exc()
+            if os.path.exists(output_dir):
+                shutil.rmtree(output_dir)
+            # remove parent folder if empty
+            if len(os.listdir(os.path.dirname(output_dir))) == 0:
+                shutil.rmtree(os.path.dirname(output_dir))
+            continue
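# A sketch (not part of the commit) of the output-folder naming used above:
# a NIfTI file name is stripped of .nii/.nii.gz to form its results folder.
from pathlib import Path

def nifti_output_folder(path: str) -> str:
    name = Path(path).name
    return name.replace(".gz", "").replace(".nii", "")

print(nifti_output_folder("/data/scans/case_001.nii.gz"))  # case_001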
Comp2Comp-main/comp2comp/utils/run.py ADDED
@@ -0,0 +1,126 @@
+import logging
+import os
+import re
+from typing import Sequence, Union
+
+logger = logging.getLogger(__name__)
+
+
+def format_output_path(
+    file_path,
+    save_dir: str = None,
+    base_dirs: Sequence[str] = None,
+    file_name: Sequence[str] = None,
+):
+    """Format output path for a given file.
+
+    Args:
+        file_path (str): File path.
+        save_dir (str, optional): Save directory. Defaults to None.
+        base_dirs (Sequence[str], optional): Base directories. Defaults to None.
+        file_name (Sequence[str], optional): File name. Defaults to None.
+
+    Returns:
+        str: Output path.
+    """
+    dirname = os.path.dirname(file_path) if not save_dir else save_dir
+
+    if save_dir and base_dirs:
+        dirname: str = os.path.dirname(file_path)
+        relative_dir = [
+            dirname.split(bdir, 1)[1] for bdir in base_dirs if dirname.startswith(bdir)
+        ][0]
+        # Trim path separator from the path
+        relative_dir = relative_dir.lstrip(os.path.sep)
+        dirname = os.path.join(save_dir, relative_dir)
+
+    if file_name is not None:
+        return os.path.join(
+            dirname,
+            "{}.h5".format(file_name),
+        )
+
+    return os.path.join(
+        dirname,
+        "{}.h5".format(os.path.splitext(os.path.basename(file_path))[0]),
+    )
+
+
+# Function that returns a list of file names, excluding
+# the extension, from the list of file paths
+def get_file_names(files):
+    """Get file names from a list of file paths.
+
+    Args:
+        files (list): List of file paths.
+
+    Returns:
+        list: List of file names.
+    """
+    file_names = []
+    for file in files:
+        file_name = os.path.splitext(os.path.basename(file))[0]
+        file_names.append(file_name)
+    return file_names
+
+
+def find_files(
+    root_dirs: Union[str, Sequence[str]],
+    max_depth: int = None,
+    exist_ok: bool = False,
+    pattern: str = None,
+):
+    """Recursively search for files.
+
+    To avoid recomputing experiments with results, set `exist_ok=False`.
+
+    Args:
+        root_dirs (`str(s)`): Root folder(s) to search.
+        max_depth (int, optional): Maximum depth to search.
+        exist_ok (bool, optional): If `True`, recompute results for
+            scans.
+        pattern (str, optional): If specified, looks for files with names
+            matching the pattern.
+
+    Return:
+        List[str]: Experiment directories to test.
+    """
+
+    def _get_files(depth: int, dir_name: str):
+        if dir_name is None or not os.path.isdir(dir_name):
+            return []
+
+        if max_depth is not None and depth > max_depth:
+            return []
+
+        files = os.listdir(dir_name)
+        ret_files = []
+        for file in files:
+            possible_dir = os.path.join(dir_name, file)
+            if os.path.isdir(possible_dir):
+                subfiles = _get_files(depth + 1, possible_dir)
+                ret_files.extend(subfiles)
+            elif os.path.isfile(possible_dir):
+                if pattern and not re.match(pattern, possible_dir):
+                    continue
+                output_path = format_output_path(possible_dir)
+                if not exist_ok and os.path.isfile(output_path):
+                    logger.info(
+                        "Skipping {} - results exist at {}".format(
+                            possible_dir, output_path
+                        )
+                    )
+                    continue
+                ret_files.append(possible_dir)
+
+        return ret_files
+
+    out_files = []
+    if isinstance(root_dirs, str):
+        root_dirs = [root_dirs]
+    for d in root_dirs:
+        out_files.extend(_get_files(0, d))
+
+    return sorted(set(out_files))
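# Illustrative usage (a sketch, not part of the commit): find DICOM files up
# to two directory levels deep, re-running even when .h5 results exist. The
# directory path is made up.
paths = find_files("/data/cohort", max_depth=2, exist_ok=True, pattern=r".*\.dcm$")
print(len(paths))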
Comp2Comp-main/comp2comp/visualization/detectron_visualizer.py ADDED
@@ -0,0 +1,1288 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import colorsys
+import logging
+import math
+from enum import Enum, unique
+from pathlib import Path
+
+import cv2
+import matplotlib as mpl
+import matplotlib.colors as mplc
+import matplotlib.figure as mplfigure
+import numpy as np
+import pycocotools.mask as mask_util
+import torch
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+
+from comp2comp.utils.colormap import random_color
+from comp2comp.visualization.dicom import to_dicom
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["ColorMode", "VisImage", "Visualizer"]
+
+
+_SMALL_OBJECT_AREA_THRESH = 1000
+_LARGE_MASK_AREA_THRESH = 120000
+_OFF_WHITE = (1.0, 1.0, 240.0 / 255)
+_BLACK = (0, 0, 0)
+_RED = (1.0, 0, 0)
+
+_KEYPOINT_THRESHOLD = 0.05
+
+
+@unique
+class ColorMode(Enum):
+    """
+    Enum of different color modes to use for instance visualizations.
+    """
+
+    IMAGE = 0
+    """
+    Picks a random color for every instance and overlay segmentations with low opacity.
+    """
+    SEGMENTATION = 1
+    """
+    Let instances of the same category have similar colors
+    (from metadata.thing_colors), and overlay them with
+    high opacity. This provides more attention on the quality of segmentation.
+    """
+    IMAGE_BW = 2
+    """
+    Same as IMAGE, but convert all areas without masks to gray-scale.
+    Only available for drawing per-instance mask predictions.
+    """
+
+
+class GenericMask:
+    """
+    Attribute:
+        polygons (list[ndarray]): list[ndarray]: polygons for this mask.
+            Each ndarray has format [x, y, x, y, ...]
+        mask (ndarray): a binary mask
+    """
+
+    def __init__(self, mask_or_polygons, height, width):
+        self._mask = self._polygons = self._has_holes = None
+        self.height = height
+        self.width = width
+
+        m = mask_or_polygons
+        if isinstance(m, dict):
+            # RLEs
+            assert "counts" in m and "size" in m
+            if isinstance(m["counts"], list):  # uncompressed RLEs
+                h, w = m["size"]
+                assert h == height and w == width
+                m = mask_util.frPyObjects(m, h, w)
+            self._mask = mask_util.decode(m)[:, :]
+            return
+
+        if isinstance(m, list):  # list[ndarray]
+            self._polygons = [np.asarray(x).reshape(-1) for x in m]
+            return
+
+        if isinstance(m, np.ndarray):  # assumed to be a binary mask
+            assert m.shape[1] != 2, m.shape
+            assert m.shape == (
+                height,
+                width,
+            ), f"mask shape: {m.shape}, target dims: {height}, {width}"
+            self._mask = m.astype("uint8")
+            return
+
+        raise ValueError(
+            "GenericMask cannot handle object {} of type '{}'".format(m, type(m))
+        )
+
+    @property
+    def mask(self):
+        if self._mask is None:
+            self._mask = self.polygons_to_mask(self._polygons)
+        return self._mask
+
+    @property
+    def polygons(self):
+        if self._polygons is None:
+            self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
+        return self._polygons
+
+    @property
+    def has_holes(self):
+        if self._has_holes is None:
+            if self._mask is not None:
+                self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
+            else:
+                self._has_holes = (
+                    False  # if original format is polygon, does not have holes
+                )
+        return self._has_holes
+
+    def mask_to_polygons(self, mask):
+        # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
+        # hierarchy. External contours (boundary) of the object are placed in hierarchy-1.
+        # Internal contours (holes) are placed in hierarchy-2.
+        # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.
+        mask = np.ascontiguousarray(
+            mask
+        )  # some versions of cv2 does not support incontiguous arr
+        res = cv2.findContours(
+            mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
+        )
+        hierarchy = res[-1]
+        if hierarchy is None:  # empty mask
+            return [], False
+        has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
+        res = res[-2]
+        res = [x.flatten() for x in res]
+        # These coordinates from OpenCV are integers in range [0, W-1 or H-1].
+        # We add 0.5 to turn them into real-value coordinate space. A better solution
+        # would be to first +0.5 and then dilate the returned polygon by 0.5.
+        res = [x + 0.5 for x in res if len(x) >= 6]
+        return res, has_holes
+
+    def polygons_to_mask(self, polygons):
+        rle = mask_util.frPyObjects(polygons, self.height, self.width)
+        rle = mask_util.merge(rle)
+        return mask_util.decode(rle)[:, :]
+
+    def area(self):
+        return self.mask.sum()
+
+    def bbox(self):
+        p = mask_util.frPyObjects(self.polygons, self.height, self.width)
+        p = mask_util.merge(p)
+        bbox = mask_util.toBbox(p)
+        bbox[2] += bbox[0]
+        bbox[3] += bbox[1]
+        return bbox
+
+
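# Illustrative usage of GenericMask (a sketch, not part of the commit):
# wrap a binary mask and derive polygons and a bounding box from it.
import numpy as np

m = np.zeros((64, 64), dtype=np.uint8)
m[20:40, 10:30] = 1
gm = GenericMask(m, height=64, width=64)
print(gm.area())  # 400 pixels
print(gm.bbox())  # [x0, y0, x1, y1]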
+class _PanopticPrediction:
+    """
+    Unify different panoptic annotation/prediction formats
+    """
+
+    def __init__(self, panoptic_seg, segments_info, metadata=None):
+        if segments_info is None:
+            assert metadata is not None
+            # If "segments_info" is None, we assume "panoptic_img" is a
+            # H*W int32 image storing the panoptic_id in the format of
+            # category_id * label_divisor + instance_id. We reserve -1 for
+            # VOID label.
+            label_divisor = metadata.label_divisor
+            segments_info = []
+            for panoptic_label in np.unique(panoptic_seg.numpy()):
+                if panoptic_label == -1:
+                    # VOID region.
+                    continue
+                pred_class = panoptic_label // label_divisor
+                isthing = (
+                    pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
+                )
+                segments_info.append(
+                    {
+                        "id": int(panoptic_label),
+                        "category_id": int(pred_class),
+                        "isthing": bool(isthing),
+                    }
+                )
+            del metadata
+
+        self._seg = panoptic_seg
+
+        self._sinfo = {s["id"]: s for s in segments_info}  # seg id -> seg info
+        segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
+        areas = areas.numpy()
+        sorted_idxs = np.argsort(-areas)
+        self._seg_ids, self._seg_areas = (
+            segment_ids[sorted_idxs],
+            areas[sorted_idxs],
+        )
+        self._seg_ids = self._seg_ids.tolist()
+        for sid, area in zip(self._seg_ids, self._seg_areas):
+            if sid in self._sinfo:
+                self._sinfo[sid]["area"] = float(area)
+
+    def non_empty_mask(self):
+        """
+        Returns:
+            (H, W) array, a mask for all pixels that have a prediction
+        """
+        empty_ids = []
+        for id in self._seg_ids:
+            if id not in self._sinfo:
+                empty_ids.append(id)
+        if len(empty_ids) == 0:
+            return np.zeros(self._seg.shape, dtype=np.uint8)
+        assert (
+            len(empty_ids) == 1
+        ), ">1 ids corresponds to no labels. This is currently not supported"
+        return (self._seg != empty_ids[0]).numpy().astype(bool)
+
+    def semantic_masks(self):
+        for sid in self._seg_ids:
+            sinfo = self._sinfo.get(sid)
+            if sinfo is None or sinfo["isthing"]:
+                # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
+                continue
+            yield (self._seg == sid).numpy().astype(bool), sinfo
+
+    def instance_masks(self):
+        for sid in self._seg_ids:
+            sinfo = self._sinfo.get(sid)
+            if sinfo is None or not sinfo["isthing"]:
+                continue
+            mask = (self._seg == sid).numpy().astype(bool)
+            if mask.sum() > 0:
+                yield mask, sinfo
+
+
+def _create_text_labels(classes, scores, class_names, is_crowd=None):
+    """
+    Args:
+        classes (list[int] or None):
+        scores (list[float] or None):
+        class_names (list[str] or None):
+        is_crowd (list[bool] or None):
+
+    Returns:
+        list[str] or None
+    """
+    labels = None
+    if classes is not None:
+        if class_names is not None and len(class_names) > 0:
+            labels = [class_names[i] for i in classes]
+        else:
+            labels = [str(i) for i in classes]
+    if scores is not None:
+        if labels is None:
+            labels = ["{:.0f}%".format(s * 100) for s in scores]
+        else:
+            labels = [
+                "{} {:.0f}%".format(lbl, s * 100) for lbl, s in zip(labels, scores)
+            ]
+    if labels is not None and is_crowd is not None:
+        labels = [
+            lbl + ("|crowd" if crowd else "") for lbl, crowd in zip(labels, is_crowd)
+        ]
+    return labels
+
+
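# A quick sketch (not part of the commit) of _create_text_labels: class ids
# are mapped to names and annotated with confidence percentages.
labels = _create_text_labels(
    classes=[0, 1], scores=[0.9, 0.75], class_names=["muscle", "fat"]
)
print(labels)  # ['muscle 90%', 'fat 75%']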
+class VisImage:
+    def __init__(self, img, scale=1.0):
+        """
+        Args:
+            img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
+            scale (float): scale the input image
+        """
+        self.img = img
+        self.scale = scale
+        self.width, self.height = img.shape[1], img.shape[0]
+        self._setup_figure(img)
+
+    def _setup_figure(self, img):
+        """
+        Args:
+            Same as in :meth:`__init__()`.
+
+        Returns:
+            fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
+            ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
+        """
+        fig = mplfigure.Figure(frameon=False)
+        self.dpi = fig.get_dpi()
+        # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
+        # (https://github.com/matplotlib/matplotlib/issues/15363)
+        fig.set_size_inches(
+            (self.width * self.scale + 1e-2) / self.dpi,
+            (self.height * self.scale + 1e-2) / self.dpi,
+        )
+        self.canvas = FigureCanvasAgg(fig)
+        # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
+        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
+        ax.axis("off")
+        self.fig = fig
+        self.ax = ax
+        self.reset_image(img)
+
+    def reset_image(self, img):
+        """
+        Args:
+            img: same as in __init__
+        """
+        img = img.astype("uint8")
+        self.ax.imshow(
+            img, extent=(0, self.width, self.height, 0), interpolation="nearest"
+        )
+
+    def save(self, filepath):
+        """
+        Args:
+            filepath (str): a string that contains the absolute path, including the file name, where
+                the visualized image will be saved.
+        """
+        # if filepath is a png or jpg
+        img = self.get_image()
+        if filepath.endswith(".png") or filepath.endswith(".jpg"):
+            self.fig.savefig(filepath)
+        if filepath.endswith(".dcm"):
+            to_dicom(img, Path(filepath))
+        return img
+
+    def get_image(self):
+        """
+        Returns:
+            ndarray:
+                the visualized image of shape (H, W, 3) (RGB) in uint8 type.
+                The shape is scaled w.r.t the input image using the given `scale` argument.
+        """
+        canvas = self.canvas
+        s, (width, height) = canvas.print_to_buffer()
+        # buf = io.BytesIO()  # works for cairo backend
+        # canvas.print_rgba(buf)
+        # width, height = self.width, self.height
+        # s = buf.getvalue()
+
+        buffer = np.frombuffer(s, dtype="uint8")
+
+        img_rgba = buffer.reshape(height, width, 4)
+        rgb, alpha = np.split(img_rgba, [3], axis=2)
+        return rgb.astype("uint8")
+
+
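# Illustrative usage of VisImage (a sketch, not part of the commit): wrap an
# RGB array, then save it as PNG (the .dcm suffix above routes to to_dicom).
# The file name is made up; save() also returns the rasterized RGB array.
import numpy as np

canvas = VisImage(np.zeros((128, 128, 3), dtype=np.uint8), scale=1.0)
rendered = canvas.save("overlay.png")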
+class Visualizer:
+    """
+    Visualizer that draws data about detection/segmentation on images.
+
+    It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
+    that draw primitive objects to images, as well as high-level wrappers like
+    `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
+    that draw composite data in some pre-defined style.
+
+    Note that the exact visualization style for the high-level wrappers are subject to change.
+    Style such as color, opacity, label contents, visibility of labels, or even the visibility
+    of objects themselves (e.g. when the object is too small) may change according
+    to different heuristics, as long as the results still look visually reasonable.
+
+    To obtain a consistent style, you can implement custom drawing functions with the
+    abovementioned primitive methods instead. If you need more customized visualization
+    styles, you can process the data yourself following their format documented in
+    tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
+    intend to satisfy everyone's preference on drawing styles.
+
+    This visualizer focuses on high rendering quality rather than performance. It is not
+    designed to be used for real-time applications.
+    """
+
+    # TODO implement a fast, rasterized version using OpenCV
+
+    def __init__(
+        self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE
+    ):
+        """
+        Args:
+            img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
+                the height and width of the image respectively. C is the number of
+                color channels. The image is required to be in RGB format since that
+                is a requirement of the Matplotlib library. The image is also expected
+                to be in the range [0, 255].
+            metadata (Metadata): dataset metadata (e.g. class names and colors)
+            instance_mode (ColorMode): defines one of the pre-defined style for drawing
+                instances on an image.
+        """
+        self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
+        # if metadata is None:
+        #     metadata = MetadataCatalog.get("__nonexist__")
+        self.metadata = metadata
+        self.output = VisImage(self.img, scale=scale)
+        self.cpu_device = torch.device("cpu")
+
+        # too small texts are useless, therefore clamp to 9
+        self._default_font_size = max(
+            np.sqrt(self.output.height * self.output.width) // 90, 10 // scale
+        )
+        self._instance_mode = instance_mode
+        self.keypoint_threshold = _KEYPOINT_THRESHOLD
+
+    def draw_instance_predictions(self, predictions):
+        """
+        Draw instance-level prediction results on an image.
+
+        Args:
+            predictions (Instances): the output of an instance detection/segmentation
+                model. Following fields will be used to draw:
+                "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
+
+        Returns:
+            output (VisImage): image object with visualizations.
+        """
+        boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
+        scores = predictions.scores if predictions.has("scores") else None
+        classes = (
+            predictions.pred_classes.tolist()
+            if predictions.has("pred_classes")
+            else None
+        )
+        labels = _create_text_labels(
+            classes, scores, self.metadata.get("thing_classes", None)
+        )
+        keypoints = (
+            predictions.pred_keypoints if predictions.has("pred_keypoints") else None
+        )
+
+        if predictions.has("pred_masks"):
+            masks = np.asarray(predictions.pred_masks)
+            masks = [
+                GenericMask(x, self.output.height, self.output.width) for x in masks
+            ]
+        else:
+            masks = None
+
+        if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get(
+            "thing_colors"
+        ):
+            colors = [
+                self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
+                for c in classes
+            ]
+            alpha = 0.8
+        else:
+            colors = None
+            alpha = 0.5
+
+        if self._instance_mode == ColorMode.IMAGE_BW:
+            self.output.reset_image(
+                self._create_grayscale_image(
+                    (predictions.pred_masks.any(dim=0) > 0).numpy()
+                    if predictions.has("pred_masks")
+                    else None
+                )
+            )
+            alpha = 0.3
+
+        self.overlay_instances(
+            masks=masks,
+            boxes=boxes,
+            labels=labels,
+            keypoints=keypoints,
+            assigned_colors=colors,
+            alpha=alpha,
+        )
+        return self.output
+
+    def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8):
+        """
+        Draw semantic segmentation predictions/labels.
+
+        Args:
+            sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
+                Each value is the integer label of the pixel.
+            area_threshold (int): segments with less than `area_threshold` are not drawn.
+            alpha (float): the larger it is, the more opaque the segmentations are.
+
+        Returns:
+            output (VisImage): image object with visualizations.
+        """
+        if isinstance(sem_seg, torch.Tensor):
+            sem_seg = sem_seg.numpy()
+        labels, areas = np.unique(sem_seg, return_counts=True)
+        sorted_idxs = np.argsort(-areas).tolist()
+        labels = labels[sorted_idxs]
+        for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):
+            try:
+                mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
+            except (AttributeError, IndexError):
+                mask_color = None
+
+            binary_mask = (sem_seg == label).astype(np.uint8)
+            text = self.metadata.stuff_classes[label]
+            self.draw_binary_mask(
+                binary_mask,
+                color=mask_color,
+                edge_color=_OFF_WHITE,
+                text=text,
+                alpha=alpha,
+                area_threshold=area_threshold,
+            )
+        return self.output
+
+    def draw_panoptic_seg(
+        self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7
+    ):
+        """
+        Draw panoptic prediction annotations or results.
+
+        Args:
+            panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
+                segment.
+            segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.
+                If it is a ``list[dict]``, each dict contains keys "id", "category_id".
+                If None, category id of each pixel is computed by
+                ``pixel // metadata.label_divisor``.
+            area_threshold (int): stuff segments with less than `area_threshold` are not drawn.
+
+        Returns:
+            output (VisImage): image object with visualizations.
+        """
+        pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)
+
+        if self._instance_mode == ColorMode.IMAGE_BW:
+            self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))
+
+        # draw mask for all semantic segments first i.e. "stuff"
+        for mask, sinfo in pred.semantic_masks():
+            category_idx = sinfo["category_id"]
+            try:
+                mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
+            except AttributeError:
+                mask_color = None
+
+            text = self.metadata.stuff_classes[category_idx]
+            self.draw_binary_mask(
+                mask,
+                color=mask_color,
+                edge_color=_OFF_WHITE,
+                text=text,
+                alpha=alpha,
+                area_threshold=area_threshold,
+            )
+
+        # draw mask for all instances second
+        all_instances = list(pred.instance_masks())
+        if len(all_instances) == 0:
+            return self.output
+        masks, sinfo = list(zip(*all_instances))
+        category_ids = [x["category_id"] for x in sinfo]
+
+        try:
+            scores = [x["score"] for x in sinfo]
+        except KeyError:
+            scores = None
+        labels = _create_text_labels(
+            category_ids,
+            scores,
+            self.metadata.thing_classes,
566
+ [x.get("iscrowd", 0) for x in sinfo],
567
+ )
568
+
569
+ try:
570
+ colors = [
571
+ self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
572
+ for c in category_ids
573
+ ]
574
+ except AttributeError:
575
+ colors = None
576
+ self.overlay_instances(
577
+ masks=masks, labels=labels, assigned_colors=colors, alpha=alpha
578
+ )
579
+
580
+ return self.output
581
+
582
+ draw_panoptic_seg_predictions = draw_panoptic_seg # backward compatibility
583
+
584
+ def overlay_instances(
585
+ self,
586
+ *,
587
+ boxes=None,
588
+ labels=None,
589
+ masks=None,
590
+ keypoints=None,
591
+ assigned_colors=None,
592
+ alpha=0.5,
593
+ ):
594
+ """
595
+ Args:
596
+ boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
597
+ or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
598
+ or a :class:`RotatedBoxes`,
599
+ or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
600
+ for the N objects in a single image,
601
+ labels (list[str]): the text to be displayed for each instance.
602
+ masks (masks-like object): Supported types are:
603
+
604
+ * :class:`detectron2.structures.PolygonMasks`,
605
+ :class:`detectron2.structures.BitMasks`.
606
+ * list[list[ndarray]]: contains the segmentation masks for all objects in one image.
607
+ The first level of the list corresponds to individual instances. The second
608
+ level to all the polygon that compose the instance, and the third level
609
+ to the polygon coordinates. The third level should have the format of
610
+ [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
611
+ * list[ndarray]: each ndarray is a binary mask of shape (H, W).
612
+ * list[dict]: each dict is a COCO-style RLE.
613
+ keypoints (Keypoint or array like): an array-like object of shape (N, K, 3),
614
+ where the N is the number of instances and K is the number of keypoints.
615
+ The last dimension corresponds to (x, y, visibility or score).
616
+ assigned_colors (list[matplotlib.colors]): a list of colors, where each color
617
+ corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
618
+ for full list of formats that the colors are accepted in.
619
+ Returns:
620
+ output (VisImage): image object with visualizations.
621
+ """
622
+ num_instances = 0
623
+ if boxes is not None:
624
+ boxes = self._convert_boxes(boxes)
625
+ num_instances = len(boxes)
626
+ if masks is not None:
627
+ masks = self._convert_masks(masks)
628
+ if num_instances:
629
+ assert len(masks) == num_instances
630
+ else:
631
+ num_instances = len(masks)
632
+ if keypoints is not None:
633
+ if num_instances:
634
+ assert len(keypoints) == num_instances
635
+ else:
636
+ num_instances = len(keypoints)
637
+ keypoints = self._convert_keypoints(keypoints)
638
+ if labels is not None:
639
+ assert len(labels) == num_instances
640
+ if assigned_colors is None:
641
+ assigned_colors = [
642
+ random_color(rgb=True, maximum=1) for _ in range(num_instances)
643
+ ]
644
+ if num_instances == 0:
645
+ return self.output
646
+ if boxes is not None and boxes.shape[1] == 5:
647
+ return self.overlay_rotated_instances(
648
+ boxes=boxes, labels=labels, assigned_colors=assigned_colors
649
+ )
650
+
651
+ # Display in largest to smallest order to reduce occlusion.
652
+ areas = None
653
+ if boxes is not None:
654
+ areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
655
+ elif masks is not None:
656
+ areas = np.asarray([x.area() for x in masks])
657
+
658
+ if areas is not None:
659
+ sorted_idxs = np.argsort(-areas).tolist()
660
+ # Re-order overlapped instances in descending order.
661
+ boxes = boxes[sorted_idxs] if boxes is not None else None
662
+ labels = [labels[k] for k in sorted_idxs] if labels is not None else None
663
+ masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None
664
+ assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
665
+ keypoints = keypoints[sorted_idxs] if keypoints is not None else None
666
+
667
+ for i in range(num_instances):
668
+ color = assigned_colors[i]
669
+ if boxes is not None:
670
+ self.draw_box(boxes[i], edge_color=color)
671
+
672
+ if masks is not None:
673
+ for segment in masks[i].polygons:
674
+ self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)
675
+
676
+ if labels is not None:
677
+ # first get a box
678
+ if boxes is not None:
679
+ x0, y0, x1, y1 = boxes[i]
680
+ text_pos = (
681
+ x0,
682
+ y0,
683
+ ) # if drawing boxes, put text on the box corner.
684
+ horiz_align = "left"
685
+ elif masks is not None:
686
+ # skip small mask without polygon
687
+ if len(masks[i].polygons) == 0:
688
+ continue
689
+
690
+ x0, y0, x1, y1 = masks[i].bbox()
691
+
692
+ # draw text in the center (defined by median) when box is not drawn
693
+ # median is less sensitive to outliers.
694
+ text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]
695
+ horiz_align = "center"
696
+ else:
697
+ continue # drawing the box confidence for keypoints isn't very useful.
698
+ # for small objects, draw text at the side to avoid occlusion
699
+ instance_area = (y1 - y0) * (x1 - x0)
700
+ if (
701
+ instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
702
+ or y1 - y0 < 40 * self.output.scale
703
+ ):
704
+ if y1 >= self.output.height - 5:
705
+ text_pos = (x1, y0)
706
+ else:
707
+ text_pos = (x0, y1)
708
+
709
+ height_ratio = (y1 - y0) / np.sqrt(
710
+ self.output.height * self.output.width
711
+ )
712
+ lighter_color = self._change_color_brightness(
713
+ color, brightness_factor=0.7
714
+ )
715
+ font_size = (
716
+ np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
717
+ * 0.5
718
+ * self._default_font_size
719
+ )
720
+ self.draw_text(
721
+ labels[i],
722
+ text_pos,
723
+ color=lighter_color,
724
+ horizontal_alignment=horiz_align,
725
+ font_size=font_size,
726
+ )
727
+
728
+ # draw keypoints
729
+ if keypoints is not None:
730
+ for keypoints_per_instance in keypoints:
731
+ self.draw_and_connect_keypoints(keypoints_per_instance)
732
+
733
+ return self.output
734
+
735
+ def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):
736
+ """
737
+ Args:
738
+ boxes (ndarray): an Nx5 numpy array of
739
+ (x_center, y_center, width, height, angle_degrees) format
740
+ for the N objects in a single image.
741
+ labels (list[str]): the text to be displayed for each instance.
742
+ assigned_colors (list[matplotlib.colors]): a list of colors, where each color
743
+ corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
744
+ for full list of formats that the colors are accepted in.
745
+
746
+ Returns:
747
+ output (VisImage): image object with visualizations.
748
+ """
749
+ num_instances = len(boxes)
750
+
751
+ if assigned_colors is None:
752
+ assigned_colors = [
753
+ random_color(rgb=True, maximum=1) for _ in range(num_instances)
754
+ ]
755
+ if num_instances == 0:
756
+ return self.output
757
+
758
+ # Display in largest to smallest order to reduce occlusion.
759
+ if boxes is not None:
760
+ areas = boxes[:, 2] * boxes[:, 3]
761
+
762
+ sorted_idxs = np.argsort(-areas).tolist()
763
+ # Re-order overlapped instances in descending order.
764
+ boxes = boxes[sorted_idxs]
765
+ labels = [labels[k] for k in sorted_idxs] if labels is not None else None
766
+ colors = [assigned_colors[idx] for idx in sorted_idxs]
767
+
768
+ for i in range(num_instances):
769
+ self.draw_rotated_box_with_label(
770
+ boxes[i],
771
+ edge_color=colors[i],
772
+ label=labels[i] if labels is not None else None,
773
+ )
774
+
775
+ return self.output
776
+
777
+ def draw_and_connect_keypoints(self, keypoints):
778
+ """
779
+ Draws keypoints of an instance and follows the rules for keypoint connections
780
+ to draw lines between appropriate keypoints. This follows color heuristics for
781
+ line color.
782
+
783
+ Args:
784
+ keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints
785
+ and the last dimension corresponds to (x, y, probability).
786
+
787
+ Returns:
788
+ output (VisImage): image object with visualizations.
789
+ """
790
+ visible = {}
791
+ keypoint_names = self.metadata.get("keypoint_names")
792
+ for idx, keypoint in enumerate(keypoints):
793
+ # draw keypoint
794
+ x, y, prob = keypoint
795
+ if prob > self.keypoint_threshold:
796
+ self.draw_circle((x, y), color=_RED)
797
+ if keypoint_names:
798
+ keypoint_name = keypoint_names[idx]
799
+ visible[keypoint_name] = (x, y)
800
+
801
+ if self.metadata.get("keypoint_connection_rules"):
802
+ for kp0, kp1, color in self.metadata.keypoint_connection_rules:
803
+ if kp0 in visible and kp1 in visible:
804
+ x0, y0 = visible[kp0]
805
+ x1, y1 = visible[kp1]
806
+ color = tuple(x / 255.0 for x in color)
807
+ self.draw_line([x0, x1], [y0, y1], color=color)
808
+
809
+ # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip
810
+ # Note that this strategy is specific to person keypoints.
811
+ # For other keypoints, it should just do nothing
812
+ try:
813
+ ls_x, ls_y = visible["left_shoulder"]
814
+ rs_x, rs_y = visible["right_shoulder"]
815
+ mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2
816
+ except KeyError:
817
+ pass
818
+ else:
819
+ # draw line from nose to mid-shoulder
820
+ nose_x, nose_y = visible.get("nose", (None, None))
821
+ if nose_x is not None:
822
+ self.draw_line(
823
+ [nose_x, mid_shoulder_x],
824
+ [nose_y, mid_shoulder_y],
825
+ color=_RED,
826
+ )
827
+
828
+ try:
829
+ # draw line from mid-shoulder to mid-hip
830
+ lh_x, lh_y = visible["left_hip"]
831
+ rh_x, rh_y = visible["right_hip"]
832
+ except KeyError:
833
+ pass
834
+ else:
835
+ mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2
836
+ self.draw_line(
837
+ [mid_hip_x, mid_shoulder_x],
838
+ [mid_hip_y, mid_shoulder_y],
839
+ color=_RED,
840
+ )
841
+ return self.output
842
+
843
+ """
844
+ Primitive drawing functions:
845
+ """
846
+
847
+ def draw_text(
848
+ self,
849
+ text,
850
+ position,
851
+ *,
852
+ font_size=None,
853
+ color="g",
854
+ horizontal_alignment="center",
855
+ rotation=0,
856
+ ):
857
+ """
858
+ Args:
859
+ text (str): class label
860
+ position (tuple): a tuple of the x and y coordinates to place text on image.
861
+ font_size (int, optional): font size of the text. If not provided, a font size
862
+ proportional to the image width is calculated and used.
863
+ color: color of the text. Refer to `matplotlib.colors` for full list
864
+ of formats that are accepted.
865
+ horizontal_alignment (str): see `matplotlib.text.Text`
866
+ rotation: rotation angle in degrees CCW
867
+
868
+ Returns:
869
+ output (VisImage): image object with text drawn.
870
+ """
871
+ if not font_size:
872
+ font_size = self._default_font_size
873
+
874
+ # since the text background is dark, we don't want the text to be dark
875
+ color = np.maximum(list(mplc.to_rgb(color)), 0.2)
876
+ color[np.argmax(color)] = max(0.8, np.max(color))
877
+
878
+ x, y = position
879
+ self.output.ax.text(
880
+ x,
881
+ y,
882
+ text,
883
+ size=font_size * self.output.scale,
884
+ family="sans-serif",
885
+ bbox={
886
+ "facecolor": "black",
887
+ "alpha": 0.8,
888
+ "pad": 0.7,
889
+ "edgecolor": "none",
890
+ },
891
+ verticalalignment="top",
892
+ horizontalalignment=horizontal_alignment,
893
+ color=color,
894
+ zorder=10,
895
+ rotation=rotation,
896
+ )
897
+ return self.output
898
+
899
+ def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
900
+ """
901
+ Args:
902
+ box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
903
+ are the coordinates of the image's top left corner. x1 and y1 are the
904
+ coordinates of the image's bottom right corner.
905
+ alpha (float): blending coefficient. Smaller values lead to more transparent boxes.
906
+ edge_color: color of the outline of the box. Refer to `matplotlib.colors`
907
+ for full list of formats that are accepted.
908
+ line_style (string): the string to use to create the outline of the boxes.
909
+
910
+ Returns:
911
+ output (VisImage): image object with box drawn.
912
+ """
913
+ x0, y0, x1, y1 = box_coord
914
+ width = x1 - x0
915
+ height = y1 - y0
916
+
917
+ linewidth = max(self._default_font_size / 4, 1)
918
+
919
+ self.output.ax.add_patch(
920
+ mpl.patches.Rectangle(
921
+ (x0, y0),
922
+ width,
923
+ height,
924
+ fill=False,
925
+ edgecolor=edge_color,
926
+ linewidth=linewidth * self.output.scale,
927
+ alpha=alpha,
928
+ linestyle=line_style,
929
+ )
930
+ )
931
+ return self.output
932
+
933
+ def draw_rotated_box_with_label(
934
+ self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None
935
+ ):
936
+ """
937
+ Draw a rotated box with label on its top-left corner.
938
+
939
+ Args:
940
+ rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
941
+ where cnt_x and cnt_y are the center coordinates of the box.
942
+ w and h are the width and height of the box. angle represents how
943
+ many degrees the box is rotated CCW with regard to the 0-degree box.
944
+ alpha (float): blending coefficient. Smaller values lead to more transparent boxes.
945
+ edge_color: color of the outline of the box. Refer to `matplotlib.colors`
946
+ for full list of formats that are accepted.
947
+ line_style (string): the string to use to create the outline of the boxes.
948
+ label (string): label for rotated box. It will not be rendered when set to None.
949
+
950
+ Returns:
951
+ output (VisImage): image object with box drawn.
952
+ """
953
+ cnt_x, cnt_y, w, h, angle = rotated_box
954
+ area = w * h
955
+ # use thinner lines when the box is small
956
+ linewidth = self._default_font_size / (
957
+ 6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3
958
+ )
959
+
960
+ theta = angle * math.pi / 180.0
961
+ c = math.cos(theta)
962
+ s = math.sin(theta)
963
+ rect = [
964
+ (-w / 2, h / 2),
965
+ (-w / 2, -h / 2),
966
+ (w / 2, -h / 2),
967
+ (w / 2, h / 2),
968
+ ]
969
+ # x: left->right ; y: top->down
970
+ rotated_rect = [
971
+ (s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect
972
+ ]
973
+ for k in range(4):
974
+ j = (k + 1) % 4
975
+ self.draw_line(
976
+ [rotated_rect[k][0], rotated_rect[j][0]],
977
+ [rotated_rect[k][1], rotated_rect[j][1]],
978
+ color=edge_color,
979
+ linestyle="--" if k == 1 else line_style,
980
+ linewidth=linewidth,
981
+ )
982
+
983
+ if label is not None:
984
+ text_pos = rotated_rect[1] # topleft corner
985
+
986
+ height_ratio = h / np.sqrt(self.output.height * self.output.width)
987
+ label_color = self._change_color_brightness(
988
+ edge_color, brightness_factor=0.7
989
+ )
990
+ font_size = (
991
+ np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
992
+ * 0.5
993
+ * self._default_font_size
994
+ )
995
+ self.draw_text(
996
+ label,
997
+ text_pos,
998
+ color=label_color,
999
+ font_size=font_size,
1000
+ rotation=angle,
1001
+ )
1002
+
1003
+ return self.output
1004
+
1005
+ def draw_circle(self, circle_coord, color, radius=3):
1006
+ """
1007
+ Args:
1008
+ circle_coord (list(int) or tuple(int)): contains the x and y coordinates
1009
+ of the center of the circle.
1010
+ color: color of the circle. Refer to `matplotlib.colors` for a full list of
1011
+ formats that are accepted.
1012
+ radius (int): radius of the circle.
1013
+
1014
+ Returns:
1015
+ output (VisImage): image object with circle drawn.
1016
+ """
1017
+ x, y = circle_coord
1018
+ self.output.ax.add_patch(
1019
+ mpl.patches.Circle(circle_coord, radius=radius, fill=False, color=color)
1020
+ )
1021
+ return self.output
1022
+
1023
+ def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
1024
+ """
1025
+ Args:
1026
+ x_data (list[int]): a list containing x values of all the points being drawn.
1027
+ Length of list should match the length of y_data.
1028
+ y_data (list[int]): a list containing y values of all the points being drawn.
1029
+ Length of list should match the length of x_data.
1030
+ color: color of the line. Refer to `matplotlib.colors` for a full list of
1031
+ formats that are accepted.
1032
+ linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
1033
+ for a full list of formats that are accepted.
1034
+ linewidth (float or None): width of the line. When it's None,
1035
+ a default value will be computed and used.
1036
+
1037
+ Returns:
1038
+ output (VisImage): image object with line drawn.
1039
+ """
1040
+ if linewidth is None:
1041
+ linewidth = self._default_font_size / 3
1042
+ linewidth = max(linewidth, 1)
1043
+ self.output.ax.add_line(
1044
+ mpl.lines.Line2D(
1045
+ x_data,
1046
+ y_data,
1047
+ linewidth=linewidth * self.output.scale,
1048
+ color=color,
1049
+ linestyle=linestyle,
1050
+ )
1051
+ )
1052
+ return self.output
1053
+
1054
+ def draw_binary_mask(
1055
+ self,
1056
+ binary_mask,
1057
+ color=None,
1058
+ *,
1059
+ edge_color=None,
1060
+ text=None,
1061
+ alpha=0.5,
1062
+ area_threshold=10,
1063
+ ):
1064
+ """
1065
+ Args:
1066
+ binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
1067
+ W is the image width. Each value in the array is either a 0 or 1 value of uint8
1068
+ type.
1069
+ color: color of the mask. Refer to `matplotlib.colors` for a full list of
1070
+ formats that are accepted. If None, will pick a random color.
1071
+ edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
1072
+ full list of formats that are accepted.
1073
+ text (str): if not None, will be drawn on the object
1074
+ alpha (float): blending coefficient. Smaller values lead to more transparent masks.
1075
+ area_threshold (float): a connected component smaller than this area will not be shown.
1076
+
1077
+ Returns:
1078
+ output (VisImage): image object with mask drawn.
1079
+ """
1080
+ if color is None:
1081
+ color = random_color(rgb=True, maximum=1)
1082
+ color = mplc.to_rgb(color)
1083
+
1084
+ has_valid_segment = False
1085
+ binary_mask = binary_mask.astype("uint8") # opencv needs uint8
1086
+ mask = GenericMask(binary_mask, self.output.height, self.output.width)
1087
+ shape2d = (binary_mask.shape[0], binary_mask.shape[1])
1088
+
1089
+ if not mask.has_holes:
1090
+ # draw polygons for regular masks
1091
+ for segment in mask.polygons:
1092
+ area = mask_util.area(
1093
+ mask_util.frPyObjects([segment], shape2d[0], shape2d[1])
1094
+ )
1095
+ if area < (area_threshold or 0):
1096
+ continue
1097
+ has_valid_segment = True
1098
+ segment = segment.reshape(-1, 2)
1099
+ self.draw_polygon(
1100
+ segment, color=color, edge_color=edge_color, alpha=alpha
1101
+ )
1102
+ else:
1103
+ # TODO: Use Path/PathPatch to draw vector graphics:
1104
+ # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
1105
+ rgba = np.zeros(shape2d + (4,), dtype="float32")
1106
+ rgba[:, :, :3] = color
1107
+ rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
1108
+ has_valid_segment = True
1109
+ self.output.ax.imshow(
1110
+ rgba, extent=(0, self.output.width, self.output.height, 0)
1111
+ )
1112
+
1113
+ if text is not None and has_valid_segment:
1114
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
1115
+ self._draw_text_in_mask(binary_mask, text, lighter_color)
1116
+ return self.output
1117
+
1118
+ def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5):
1119
+ """
1120
+ Args:
1121
+ soft_mask (ndarray): float array of shape (H, W), each value in [0, 1].
1122
+ color: color of the mask. Refer to `matplotlib.colors` for a full list of
1123
+ formats that are accepted. If None, will pick a random color.
1124
+ text (str): if not None, will be drawn on the object
1125
+ alpha (float): blending coefficient. Smaller values lead to more transparent masks.
1126
+
1127
+ Returns:
1128
+ output (VisImage): image object with mask drawn.
1129
+ """
1130
+ if color is None:
1131
+ color = random_color(rgb=True, maximum=1)
1132
+ color = mplc.to_rgb(color)
1133
+
1134
+ shape2d = (soft_mask.shape[0], soft_mask.shape[1])
1135
+ rgba = np.zeros(shape2d + (4,), dtype="float32")
1136
+ rgba[:, :, :3] = color
1137
+ rgba[:, :, 3] = soft_mask * alpha
1138
+ self.output.ax.imshow(
1139
+ rgba, extent=(0, self.output.width, self.output.height, 0)
1140
+ )
1141
+
1142
+ if text is not None:
1143
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
1144
+ binary_mask = (soft_mask > 0.5).astype("uint8")
1145
+ self._draw_text_in_mask(binary_mask, text, lighter_color)
1146
+ return self.output
1147
+
1148
+ def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
1149
+ """
1150
+ Args:
1151
+ segment: numpy array of shape Nx2, containing all the points in the polygon.
1152
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
1153
+ formats that are accepted.
1154
+ edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
1155
+ full list of formats that are accepted. If not provided, no explicit edge
1156
+ color is set and the matplotlib default is used.
1157
+ alpha (float): blending coefficient. Smaller values lead to more transparent masks.
1158
+
1159
+ Returns:
1160
+ output (VisImage): image object with polygon drawn.
1161
+ """
1162
+ if edge_color is not None:
1163
+ """
1164
+ # make edge color darker than the polygon color
1165
+ if alpha > 0.8:
1166
+ edge_color = self._change_color_brightness(color, brightness_factor=-0.7)
1167
+ else:
1168
+ edge_color = color
1169
+ """
1170
+ edge_color = mplc.to_rgb(edge_color) + (1,)
1171
+
1172
+ polygon = mpl.patches.Polygon(
1173
+ segment,
1174
+ fill=True,
1175
+ facecolor=mplc.to_rgb(color) + (alpha,),
1176
+ edgecolor=edge_color,
1177
+ linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
1178
+ )
1179
+ self.output.ax.add_patch(polygon)
1180
+ return self.output
1181
+
1182
+ """
1183
+ Internal methods:
1184
+ """
1185
+
1186
+ def _jitter(self, color):
1187
+ """
1188
+ Randomly modifies given color to produce a slightly different color than the color given.
1189
+
1190
+ Args:
1191
+ color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color
1192
+ picked. The values in the tuple are in the [0.0, 1.0] range.
1193
+
1194
+ Returns:
1195
+ jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the
1196
+ color after being jittered. The values in the tuple are in the [0.0, 1.0] range.
1197
+ """
1198
+ color = mplc.to_rgb(color)
1199
+ vec = np.random.rand(3)
1200
+ # better to do it in another color space
1201
+ vec = vec / np.linalg.norm(vec) * 0.5
1202
+ res = np.clip(vec + color, 0, 1)
1203
+ return tuple(res)
1204
+
1205
+ def _create_grayscale_image(self, mask=None):
1206
+ """
1207
+ Create a grayscale version of the original image.
1208
+ The colors in masked area, if given, will be kept.
1209
+ """
1210
+ img_bw = self.img.astype("f4").mean(axis=2)
1211
+ img_bw = np.stack([img_bw] * 3, axis=2)
1212
+ if mask is not None:
1213
+ img_bw[mask] = self.img[mask]
1214
+ return img_bw
1215
+
1216
+ def _change_color_brightness(self, color, brightness_factor):
1217
+ """
1218
+ Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
1219
+ less or more lightness than the original color.
1220
+
1221
+ Args:
1222
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
1223
+ formats that are accepted.
1224
+ brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
1225
+ 0 will correspond to no change, a factor in [-1.0, 0) range will result in
1226
+ a darker color and a factor in (0, 1.0] range will result in a lighter color.
1227
+
1228
+ Returns:
1229
+ modified_color (tuple[double]): a tuple containing the RGB values of the
1230
+ modified color. Each value in the tuple is in the [0.0, 1.0] range.
1231
+ """
1232
+ assert brightness_factor >= -1.0 and brightness_factor <= 1.0
1233
+ color = mplc.to_rgb(color)
1234
+ polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
1235
+ modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
1236
+ modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
1237
+ modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
1238
+ modified_color = colorsys.hls_to_rgb(
1239
+ polygon_color[0], modified_lightness, polygon_color[2]
1240
+ )
1241
+ return modified_color
1242
+
1243
+ def _convert_masks(self, masks_or_polygons):
1244
+ """
1245
+ Convert different formats of masks or polygons to a list of :class:`GenericMask` objects.
1246
+
1247
+ Returns:
1248
+ list[GenericMask]:
1249
+ """
1250
+
1251
+ m = masks_or_polygons
1252
+ if isinstance(m, torch.Tensor):
1253
+ m = m.numpy()
1254
+ ret = []
1255
+ for x in m:
1256
+ if isinstance(x, GenericMask):
1257
+ ret.append(x)
1258
+ else:
1259
+ ret.append(GenericMask(x, self.output.height, self.output.width))
1260
+ return ret
1261
+
1262
+ def _draw_text_in_mask(self, binary_mask, text, color):
1263
+ """
1264
+ Find proper places to draw text given a binary mask.
1265
+ """
1266
+ # TODO sometimes drawn on wrong objects. the heuristics here can improve.
1267
+ _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(
1268
+ binary_mask, 8
1269
+ )
1270
+ if stats[1:, -1].size == 0:
1271
+ return
1272
+ largest_component_id = np.argmax(stats[1:, -1]) + 1
1273
+
1274
+ # draw text on the largest component, as well as other very large components.
1275
+ for cid in range(1, _num_cc):
1276
+ if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
1277
+ # median is more stable than centroid
1278
+ # center = centroids[largest_component_id]
1279
+ center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
1280
+ self.draw_text(text, center, color=color)
1281
+
1282
+ def get_output(self):
1283
+ """
1284
+ Returns:
1285
+ output (VisImage): the image output containing the visualizations added
1286
+ to the image.
1287
+ """
1288
+ return self.output
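For orientation, here is a minimal usage sketch of the primitive drawing methods above (not part of the file). It assumes the surrounding class is named `Visualizer`, as in detectron2, and that the `VisImage` class defined earlier in this file keeps detectron2's `get_image()` helper; the image and coordinates are made up:

```
import numpy as np

img = np.zeros((256, 256, 3), dtype=np.uint8)  # hypothetical RGB input in [0, 255]
vis = Visualizer(img, metadata=None, scale=1.0)

vis.draw_box((40, 40, 160, 120), edge_color="g")            # XYXY box
vis.draw_text("L3", (40, 40), horizontal_alignment="left")  # label at the box corner
mask = np.zeros((256, 256), dtype=np.uint8)
mask[150:220, 60:200] = 1
vis.draw_binary_mask(mask, color="r", alpha=0.5)            # filled polygon overlay

rendered = vis.get_output().get_image()  # H x W x 3 uint8 array
```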
Comp2Comp-main/comp2comp/visualization/dicom.py ADDED
@@ -0,0 +1,73 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import numpy as np
5
+ import pydicom
6
+ from PIL import Image
7
+ from pydicom.dataset import Dataset, FileMetaDataset
8
+ from pydicom.uid import ExplicitVRLittleEndian
9
+
10
+
11
+ def to_dicom(input, output_path, plane="axial"):
12
+ """Converts a png image to a dicom image. Written with assistance from ChatGPT."""
13
+ if isinstance(input, str) or isinstance(input, Path):
14
+ png_path = input
15
+ dicom_path = os.path.join(
16
+ output_path, os.path.basename(png_path).replace(".png", ".dcm")
17
+ )
18
+ image = Image.open(png_path)
19
+ image_array = np.array(image)
20
+ image_array = image_array[:, :, :3]
21
+ else:
22
+ image_array = input
23
+ dicom_path = output_path
24
+
25
+ meta = FileMetaDataset()
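+ # "1.2.840.10008.5.1.4.1.1.7" is the Secondary Capture Image Storage SOP Class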
26
+ meta.MediaStorageSOPClassUID = "1.2.840.10008.5.1.4.1.1.7"
27
+ meta.MediaStorageSOPInstanceUID = pydicom.uid.generate_uid()
28
+ meta.TransferSyntaxUID = ExplicitVRLittleEndian
29
+ meta.ImplementationClassUID = pydicom.uid.PYDICOM_IMPLEMENTATION_UID
30
+
31
+ ds = Dataset()
32
+ ds.file_meta = meta
33
+ ds.is_little_endian = True
34
+ ds.is_implicit_VR = False
35
+ ds.SOPClassUID = "1.2.840.10008.5.1.4.1.1.7"
36
+ ds.SOPInstanceUID = pydicom.uid.generate_uid()
37
+ ds.PatientName = "John Doe"
38
+ ds.PatientID = "123456"
39
+ ds.Modality = "OT"
40
+ ds.SeriesInstanceUID = pydicom.uid.generate_uid()
41
+ ds.StudyInstanceUID = pydicom.uid.generate_uid()
42
+ ds.FrameOfReferenceUID = pydicom.uid.generate_uid()
43
+ ds.BitsAllocated = 8
44
+ ds.BitsStored = 8
45
+ ds.HighBit = 7
46
+ ds.PhotometricInterpretation = "RGB"
47
+ ds.PixelRepresentation = 0
48
+ ds.Rows = image_array.shape[0]
49
+ ds.Columns = image_array.shape[1]
50
+ ds.SamplesPerPixel = 3
51
+ ds.PlanarConfiguration = 0
52
+
53
+ if plane.lower() == "axial":
54
+ ds.ImageOrientationPatient = [1, 0, 0, 0, 1, 0]
55
+ elif plane.lower() == "sagittal":
56
+ ds.ImageOrientationPatient = [0, 1, 0, 0, 0, -1]
57
+ elif plane.lower() == "coronal":
58
+ ds.ImageOrientationPatient = [1, 0, 0, 0, 0, -1]
59
+ else:
60
+ raise ValueError(
61
+ "Invalid plane value. Must be 'axial', 'sagittal', or 'coronal'."
62
+ )
63
+
64
+ ds.PixelData = image_array.tobytes()
65
+ pydicom.dcmwrite(dicom_path, ds, write_like_original=False)  # dcmwrite supersedes the deprecated write_file
66
+
67
+
68
+ # Example usage
69
+ if __name__ == "__main__":
70
+ png_path = "../../figures/spine_example.png"
71
+ output_path = "./"
72
+ plane = "sagittal"
73
+ to_dicom(png_path, output_path, plane)
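A hypothetical sanity check for the example above: with output_path "./", the function writes ./spine_example.dcm, and reading it back with pydicom should round-trip the RGB pixel data:

```
import pydicom

ds = pydicom.dcmread("./spine_example.dcm")
print(ds.Rows, ds.Columns, ds.PhotometricInterpretation)  # H, W, "RGB"
rgb = ds.pixel_array  # (Rows, Columns, 3) uint8, since SamplesPerPixel=3 and PlanarConfiguration=0
```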
Comp2Comp-main/comp2comp/visualization/linear_planar_reformation.py ADDED
@@ -0,0 +1,96 @@
1
+ """
2
+ @author: louisblankemeier
3
+ """
4
+
5
+ import numpy as np
6
+
7
+
8
+ def linear_planar_reformation(
9
+ medical_volume: np.ndarray, segmentation: np.ndarray, centroids, dimension="axial"
10
+ ):
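+ """Reformat a volume and its segmentation along a path through the given centroids.
+
+ Centroids are (sagittal, coronal, axial) voxel indices. Between consecutive
+ centroids, the in-plane index is linearly interpolated so that the extracted
+ 2D image and label map follow the centroid path in the requested plane.
+ """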
11
+ if dimension == "sagittal" or dimension == "coronal":
12
+ centroids = sorted(centroids, key=lambda x: x[2])
13
+ elif dimension == "axial":
14
+ centroids = sorted(centroids, key=lambda x: x[0])
15
+
16
+ centroids = [(int(x[0]), int(x[1]), int(x[2])) for x in centroids]
17
+ sagittal_centroids = [centroids[i][0] for i in range(0, len(centroids))]
18
+ coronal_centroids = [centroids[i][1] for i in range(0, len(centroids))]
19
+ axial_centroids = [centroids[i][2] for i in range(0, len(centroids))]
20
+
21
+ sagittal_vals, coronal_vals, axial_vals = [], [], []
22
+
23
+ if dimension == "sagittal":
24
+ sagittal_vals = [sagittal_centroids[0]] * axial_centroids[0]
25
+
26
+ if dimension == "coronal":
27
+ coronal_vals = [coronal_centroids[0]] * axial_centroids[0]
28
+
29
+ if dimension == "axial":
30
+ axial_vals = [axial_centroids[0]] * sagittal_centroids[0]
31
+
32
+ for i in range(1, len(axial_centroids)):
33
+ if dimension == "sagittal" or dimension == "coronal":
34
+ num = axial_centroids[i] - axial_centroids[i - 1]
35
+ elif dimension == "axial":
36
+ num = sagittal_centroids[i] - sagittal_centroids[i - 1]
37
+
38
+ if dimension == "sagittal":
39
+ interp = list(
40
+ np.linspace(sagittal_centroids[i - 1], sagittal_centroids[i], num=num)
41
+ )
42
+ sagittal_vals.extend(interp)
43
+
44
+ if dimension == "coronal":
45
+ interp = list(
46
+ np.linspace(coronal_centroids[i - 1], coronal_centroids[i], num=num)
47
+ )
48
+ coronal_vals.extend(interp)
49
+
50
+ if dimension == "axial":
51
+ interp = list(
52
+ np.linspace(axial_centroids[i - 1], axial_centroids[i], num=num)
53
+ )
54
+ axial_vals.extend(interp)
55
+
56
+ if dimension == "sagittal":
57
+ sagittal_vals.extend(
58
+ [sagittal_centroids[-1]] * (medical_volume.shape[2] - len(sagittal_vals))
59
+ )
60
+ sagittal_vals = np.array(sagittal_vals)
61
+ sagittal_vals = sagittal_vals.astype(int)
62
+
63
+ if dimension == "coronal":
64
+ coronal_vals.extend(
65
+ [coronal_centroids[-1]] * (medical_volume.shape[2] - len(coronal_vals))
66
+ )
67
+ coronal_vals = np.array(coronal_vals)
68
+ coronal_vals = coronal_vals.astype(int)
69
+
70
+ if dimension == "axial":
71
+ axial_vals.extend(
72
+ [axial_centroids[-1]] * (medical_volume.shape[0] - len(axial_vals))
73
+ )
74
+ axial_vals = np.array(axial_vals)
75
+ axial_vals = axial_vals.astype(int)
76
+
77
+ if dimension == "sagittal":
78
+ sagittal_image = medical_volume[sagittal_vals, :, range(len(sagittal_vals))]
79
+ sagittal_label = segmentation[sagittal_vals, :, range(len(sagittal_vals))]
80
+
81
+ if dimension == "coronal":
82
+ coronal_image = medical_volume[:, coronal_vals, range(len(coronal_vals))]
83
+ coronal_label = segmentation[:, coronal_vals, range(len(coronal_vals))]
84
+
85
+ if dimension == "axial":
86
+ axial_image = medical_volume[range(len(axial_vals)), :, axial_vals]
87
+ axial_label = segmentation[range(len(axial_vals)), :, axial_vals]
88
+
89
+ if dimension == "sagittal":
90
+ return sagittal_image, sagittal_label
91
+
92
+ if dimension == "coronal":
93
+ return coronal_image, coronal_label
94
+
95
+ if dimension == "axial":
96
+ return axial_image, axial_label
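A minimal usage sketch with synthetic data (shapes and centroid values are made up). Centroids are assumed to be (sagittal, coronal, axial) voxel indices, matching the sorting and indexing above:

```
import numpy as np

volume = np.random.rand(64, 64, 40)              # (sagittal, coronal, axial)
labels = np.zeros(volume.shape, dtype=np.uint8)  # hypothetical segmentation
centroids = [(30, 32, 5), (34, 30, 20), (28, 33, 39)]

image2d, label2d = linear_planar_reformation(volume, labels, centroids, dimension="sagittal")
print(image2d.shape)  # (40, 64): one row per axial index, following the interpolated path
```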
Comp2Comp-main/docs/Local Implementation @ M1 arm64 Silicon.md ADDED
@@ -0,0 +1,67 @@
1
+ # Local Implementation @ M1/arm64/AppleSilicon
2
+
3
+ Due to dependencies and differences in architecture, the direct installation of *Comp2Comp* using install.sh or setup.py did not work on a local machine with arm64 / Apple Silicon running macOS. This guide is mainly based on [issue #30](https://github.com/StanfordMIMI/Comp2Comp/issues/30). Most of the problems I encountered are caused by requiring TensorFlow and PyTorch in the same environment, which (especially for TensorFlow) can be tricky. Thus, this guide focuses more on setting up the environment on arm64 / Apple Silicon than on *Comp2Comp* or *TotalSegmentator* itself.
4
+
5
+ ## Installation
6
+ Comp2Comp requires TensorFlow, and TotalSegmentator requires PyTorch. At the moment, neither *Comp2Comp* nor *TotalSegmentator* can make use of the M1 GPU; even so, using the arm64-specific versions is necessary.
7
+
8
+ ### TensorFlow
9
+ For reference:
10
+ - https://developer.apple.com/metal/tensorflow-plugin/
11
+ - https://developer.apple.com/forums/thread/683757
12
+ - https://developer.apple.com/forums/thread/686926?page=2
13
+
14
+ 1. Create an environment (Python 3.8 or 3.9) using miniforge: https://github.com/conda-forge/miniforge. (TensorFlow did not work for others using Anaconda; you might get it running by using -c apple and -c conda-forge for the following steps. However, I am not sure whether the channel (and the retrieved packages) or Anaconda's Python itself is the problem.)
15
+
16
+ 2. Install TensorFlow and tensorflow-metal in these versions:
17
+ ```
18
+ conda install -c apple tensorflow-deps=2.9.0 -y
19
+ python -m pip install tensorflow-macos==2.9
20
+ python -m pip install tensorflow-metal==0.5.0
21
+ ```
22
+ If you use other methods to install TensorFlow, version 2.11.0 might be the best option; version 2.12.0 has caused some problems.
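To verify the TensorFlow install, a quick check from the environment's Python interpreter (a hypothetical smoke test, nothing official) should print the version and, with tensorflow-metal, list a GPU device:

```
import tensorflow as tf

print(tf.__version__)                     # e.g. 2.9.x
print(tf.config.list_physical_devices())  # tensorflow-metal should register a GPU device
```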
23
+
24
+ ### PyTorch
25
+ For reference: https://pytorch.org. The nightly build is not needed (at least for -c conda-forge or -c pytorch); the default build already supports GPU acceleration on arm64.
26
+
27
+ 3. Install PyTorch
28
+ ```
29
+ conda install pytorch torchvision torchaudio -c pytorch
30
+ ```
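Again, a quick hedged check: PyTorch should import, and recent arm64 builds report MPS availability (note that, as mentioned in the ToDos below, TotalSegmentator currently runs on the CPU anyway):

```
import torch

print(torch.__version__)
print(torch.backends.mps.is_available())  # True on arm64 builds with MPS support
```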
31
+
32
+ ### Other Dependencies (Numpy and scikit-learn)
33
+ 4. Install other packages
34
+ ```
35
+ conda install -c conda-forge numpy scikit-learn -y
36
+ ```
37
+
38
+ ### TotalSegmentator
39
+ Louis et al. modified the original *TotalSegmentator* (https://github.com/wasserth/TotalSegmentator) for use with *Comp2Comp*. *Comp2Comp* does not work with the original version. With the current version of the modified *TotalSegmentator* (https://github.com/StanfordMIMI/TotalSegmentator), no adaptations are necessary.
40
+
41
+ ### Comp2Comp
42
+ For *Comp2Comp* on M1 however, it is important **not** to use bin/install.sh, as some of the predefined requirements won't work. Thus:
43
+
44
+ 5. Clone *Comp2Comp*
45
+ ```
46
+ git clone https://github.com/StanfordMIMI/Comp2Comp.git
47
+ ```
48
+
49
+ 6. Modify setup.py by
50
+ - removing `"numpy==1.23.5"`
51
+ - removing `"tensorflow>=2.0.0"`
52
+
53
+ (You have already installed these manually.)
54
+
55
+ 7. Install *Comp2Comp* with
56
+ ```
57
+ python -m pip install -e .
58
+ ```
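As a final smoke test (a minimal sketch, nothing official), confirm the editable install resolves from the environment:

```
import comp2comp

print(comp2comp.__file__)  # should point into the cloned repository
```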
59
+
60
+ ## Performance
61
+ Using M1Max w/ 64GB RAM
62
+ - `process 2d` (Comp2Comp in predefined slices): 250 slices in 14.2sec / 361 slices in 17.9sec
63
+ - `process 3d` (segmentation of spine and identification of slices using TotalSegmentator, Comp2Comp in identified slices): high res, full body scan, 1367sec
64
+
65
+ ## ToDos / Nice2Have / Future
66
+ - Integrating and using `--fast` and `--body_seg` for TotalSegmentator might be preferable
67
+ - TotalSegmentator works only with CUDA-compatible GPUs (!="mps"). I am not sure about `torch.device("mps")` support in the future; see also https://github.com/wasserth/TotalSegmentator/issues/39. Currently, only the CPU is used.
Comp2Comp-main/docs/Makefile ADDED
@@ -0,0 +1,20 @@
1
+ # Minimal makefile for Sphinx documentation
2
+ #
3
+
4
+ # You can set these variables from the command line, and also
5
+ # from the environment for the first two.
6
+ SPHINXOPTS ?=
7
+ SPHINXBUILD ?= sphinx-build
8
+ SOURCEDIR = source
9
+ BUILDDIR = build
10
+
11
+ # Put it first so that "make" without argument is like "make help".
12
+ help:
13
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14
+
15
+ .PHONY: help Makefile
16
+
17
+ # Catch-all target: route all unknown targets to Sphinx using the new
18
+ # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19
+ %: Makefile
20
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
Comp2Comp-main/docs/make.bat ADDED
@@ -0,0 +1,35 @@
1
+ @ECHO OFF
2
+
3
+ pushd %~dp0
4
+
5
+ REM Command file for Sphinx documentation
6
+
7
+ if "%SPHINXBUILD%" == "" (
8
+ set SPHINXBUILD=sphinx-build
9
+ )
10
+ set SOURCEDIR=source
11
+ set BUILDDIR=build
12
+
13
+ %SPHINXBUILD% >NUL 2>NUL
14
+ if errorlevel 9009 (
15
+ echo.
16
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17
+ echo.installed, then set the SPHINXBUILD environment variable to point
18
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
19
+ echo.may add the Sphinx directory to PATH.
20
+ echo.
21
+ echo.If you don't have Sphinx installed, grab it from
22
+ echo.https://www.sphinx-doc.org/
23
+ exit /b 1
24
+ )
25
+
26
+ if "%1" == "" goto help
27
+
28
+ %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29
+ goto end
30
+
31
+ :help
32
+ %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33
+
34
+ :end
35
+ popd
Comp2Comp-main/docs/requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ sphinx
2
+ sphinx-rtd-theme
3
+ recommonmark
4
+ sphinx_bootstrap_theme
5
+ sphinxcontrib-bibtex>=2.0.0
6
+ m2r2