This view is limited to 50 files because it contains too many changes.  See the raw diff here.
Files changed (50) hide show
  1. README.md +5 -4
  2. SegVol_v1.pth +3 -0
  3. __pycache__/utils.cpython-39.pyc +0 -0
  4. app.py +308 -0
  5. model/LICENSE +21 -0
  6. model/README.md +74 -0
  7. model/__pycache__/inference_cpu.cpython-39.pyc +0 -0
  8. model/asset/FLARE22_Tr_0002_0000.nii.gz +3 -0
  9. model/asset/FLARE22_Tr_0005_0000.nii.gz +3 -0
  10. model/asset/FLARE22_Tr_0034_0000.nii.gz +3 -0
  11. model/asset/FLARE22_Tr_0045_0000.nii.gz +3 -0
  12. model/asset/model.png +0 -0
  13. model/asset/overview back.png +0 -0
  14. model/asset/overview.png +0 -0
  15. model/config/clip/config.json +157 -0
  16. model/config/clip/special_tokens_map.json +1 -0
  17. model/config/clip/tokenizer.json +0 -0
  18. model/config/clip/tokenizer_config.json +1 -0
  19. model/config/clip/vocab.json +0 -0
  20. model/config/config_demo.json +8 -0
  21. model/data_process/__pycache__/demo_data_process.cpython-39.pyc +0 -0
  22. model/data_process/demo_data_process.py +91 -0
  23. model/inference_cpu.py +173 -0
  24. model/inference_demo.py +219 -0
  25. model/network/__pycache__/model.cpython-39.pyc +0 -0
  26. model/network/model.py +91 -0
  27. model/script/inference_demo.sh +8 -0
  28. model/segment_anything_volumetric/.ipynb_checkpoints/build_sam-checkpoint.py +172 -0
  29. model/segment_anything_volumetric/__init__.py +12 -0
  30. model/segment_anything_volumetric/__pycache__/__init__.cpython-310.pyc +0 -0
  31. model/segment_anything_volumetric/__pycache__/__init__.cpython-39.pyc +0 -0
  32. model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-310.pyc +0 -0
  33. model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-39.pyc +0 -0
  34. model/segment_anything_volumetric/__pycache__/build_sam.cpython-310.pyc +0 -0
  35. model/segment_anything_volumetric/__pycache__/build_sam.cpython-39.pyc +0 -0
  36. model/segment_anything_volumetric/__pycache__/predictor.cpython-310.pyc +0 -0
  37. model/segment_anything_volumetric/__pycache__/predictor.cpython-39.pyc +0 -0
  38. model/segment_anything_volumetric/automatic_mask_generator.py +372 -0
  39. model/segment_anything_volumetric/build_sam.py +111 -0
  40. model/segment_anything_volumetric/modeling/.ipynb_checkpoints/image_encoder_swin-checkpoint.py +709 -0
  41. model/segment_anything_volumetric/modeling/.ipynb_checkpoints/prompt_encoder-checkpoint.py +232 -0
  42. model/segment_anything_volumetric/modeling/__init__.py +11 -0
  43. model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-310.pyc +0 -0
  44. model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-39.pyc +0 -0
  45. model/segment_anything_volumetric/modeling/__pycache__/common.cpython-310.pyc +0 -0
  46. model/segment_anything_volumetric/modeling/__pycache__/common.cpython-39.pyc +0 -0
  47. model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-310.pyc +0 -0
  48. model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-39.pyc +0 -0
  49. model/segment_anything_volumetric/modeling/__pycache__/image_encoder_swin.cpython-39.pyc +0 -0
  50. model/segment_anything_volumetric/modeling/__pycache__/mask_decoder.cpython-310.pyc +0 -0
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
  title: SegVol
3
- emoji: 📈
4
- colorFrom: gray
5
- colorTo: red
6
  sdk: streamlit
7
- sdk_version: 1.29.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: SegVol
3
+ emoji: 🏢
4
+ colorFrom: indigo
5
+ colorTo: blue
6
  sdk: streamlit
7
+ sdk_version: 1.28.2
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
SegVol_v1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b751dc95f1a0c0c6086c1e6fa7f8a17bbb87635e5226e15f5d156fbd364dbb85
3
+ size 1660308695
__pycache__/utils.cpython-39.pyc ADDED
Binary file (3.85 kB). View file
 
app.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_drawable_canvas import st_canvas
3
+ from streamlit_image_coordinates import streamlit_image_coordinates
4
+
5
+
6
+ from model.data_process.demo_data_process import process_ct_gt
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ from PIL import Image, ImageDraw
10
+ import monai.transforms as transforms
11
+ from utils import show_points, make_fig, reflect_points_into_model, initial_rectangle, reflect_json_data_to_3D_box, reflect_box_into_model, run
12
+
13
+ print('script run')
14
+
15
+ #############################################
16
+ # init session_state
17
+ if 'option' not in st.session_state:
18
+ st.session_state.option = None
19
+ if 'text_prompt' not in st.session_state:
20
+ st.session_state.text_prompt = None
21
+
22
+ if 'reset_demo_case' not in st.session_state:
23
+ st.session_state.reset_demo_case = False
24
+
25
+ if 'preds_3D' not in st.session_state:
26
+ st.session_state.preds_3D = None
27
+
28
+ if 'data_item' not in st.session_state:
29
+ st.session_state.data_item = None
30
+
31
+ if 'points' not in st.session_state:
32
+ st.session_state.points = []
33
+
34
+ if 'use_text_prompt' not in st.session_state:
35
+ st.session_state.use_text_prompt = False
36
+
37
+ if 'use_point_prompt' not in st.session_state:
38
+ st.session_state.use_point_prompt = False
39
+
40
+ if 'use_box_prompt' not in st.session_state:
41
+ st.session_state.use_box_prompt = False
42
+
43
+ if 'rectangle_3Dbox' not in st.session_state:
44
+ st.session_state.rectangle_3Dbox = [0,0,0,0,0,0]
45
+
46
+ if 'irregular_box' not in st.session_state:
47
+ st.session_state.irregular_box = False
48
+
49
+ if 'running' not in st.session_state:
50
+ st.session_state.running = False
51
+
52
+ if 'transparency' not in st.session_state:
53
+ st.session_state.transparency = 0.25
54
+
55
+ case_list = [
56
+ 'model/asset/FLARE22_Tr_0002_0000.nii.gz',
57
+ 'model/asset/FLARE22_Tr_0005_0000.nii.gz',
58
+ 'model/asset/FLARE22_Tr_0034_0000.nii.gz',
59
+ 'model/asset/FLARE22_Tr_0045_0000.nii.gz'
60
+ ]
61
+
62
+ #############################################
63
+
64
+ #############################################
65
+ # reset functions
66
+ def clear_prompts():
67
+ st.session_state.points = []
68
+ st.session_state.rectangle_3Dbox = [0,0,0,0,0,0]
69
+
70
+ def reset_demo_case():
71
+ st.session_state.data_item = None
72
+ st.session_state.reset_demo_case = True
73
+ clear_prompts()
74
+
75
+ def clear_file():
76
+ st.session_state.option = None
77
+ process_ct_gt.clear()
78
+ reset_demo_case()
79
+ clear_prompts()
80
+
81
+ #############################################
82
+
83
+ st.image(Image.open('model/asset/overview back.png'), use_column_width=True)
84
+
85
+ github_col, arxive_col = st.columns(2)
86
+
87
+ with github_col:
88
+ st.write('GitHub repo:https://github.com/BAAI-DCAI/SegVol')
89
+
90
+ with arxive_col:
91
+ st.write('Paper:https://arxiv.org/abs/2311.13385')
92
+
93
+
94
+ # modify demo case here
95
+ demo_type = st.radio(
96
+ "Demo case source",
97
+ ["Select", "Upload"],
98
+ on_change=clear_file
99
+ )
100
+
101
+ if demo_type=="Select":
102
+ uploaded_file = st.selectbox(
103
+ "Select a demo case",
104
+ case_list,
105
+ index=None,
106
+ placeholder="Select a demo case...",
107
+ on_change=reset_demo_case
108
+ )
109
+ else:
110
+ uploaded_file = st.file_uploader("Upload demo case(nii.gz)", type='nii.gz', on_change=reset_demo_case)
111
+
112
+ st.session_state.option = uploaded_file
113
+
114
+ if st.session_state.option is not None and \
115
+ st.session_state.reset_demo_case or (st.session_state.data_item is None and st.session_state.option is not None):
116
+
117
+ st.session_state.data_item = process_ct_gt(st.session_state.option)
118
+ st.session_state.reset_demo_case = False
119
+ st.session_state.preds_3D = None
120
+
121
+ prompt_col1, prompt_col2 = st.columns(2)
122
+
123
+ with prompt_col1:
124
+ st.session_state.use_text_prompt = st.toggle('Sematic prompt')
125
+ text_prompt_type = st.radio(
126
+ "Sematic prompt type",
127
+ ["Predefined", "Custom"],
128
+ disabled=(not st.session_state.use_text_prompt)
129
+ )
130
+ if text_prompt_type == "Predefined":
131
+ pre_text = st.selectbox(
132
+ "Predefined anatomical category:",
133
+ ['liver', 'right kidney', 'spleen', 'pancreas', 'aorta', 'inferior vena cava', 'right adrenal gland', 'left adrenal gland', 'gallbladder', 'esophagus', 'stomach', 'duodenum', 'left kidney'],
134
+ index=None,
135
+ disabled=(not st.session_state.use_text_prompt)
136
+ )
137
+ else:
138
+ pre_text = st.text_input('Enter an Anatomical word or phrase:', None, max_chars=20,
139
+ disabled=(not st.session_state.use_text_prompt))
140
+ if pre_text is None or len(pre_text) > 0:
141
+ st.session_state.text_prompt = pre_text
142
+ else:
143
+ st.session_state.text_prompt = None
144
+
145
+
146
+ with prompt_col2:
147
+ spatial_prompt_on = st.toggle('Spatial prompt', on_change=clear_prompts)
148
+ spatial_prompt = st.radio(
149
+ "Spatial prompt type",
150
+ ["Point prompt", "Box prompt"],
151
+ on_change=clear_prompts,
152
+ disabled=(not spatial_prompt_on))
153
+
154
+ if spatial_prompt == "Point prompt":
155
+ st.session_state.use_point_prompt = True
156
+ st.session_state.use_box_prompt = False
157
+ elif spatial_prompt == "Box prompt":
158
+ st.session_state.use_box_prompt = True
159
+ st.session_state.use_point_prompt = False
160
+ else:
161
+ st.session_state.use_point_prompt = False
162
+ st.session_state.use_box_prompt = False
163
+
164
+ if not spatial_prompt_on:
165
+ st.session_state.use_point_prompt = False
166
+ st.session_state.use_box_prompt = False
167
+
168
+ if not st.session_state.use_text_prompt:
169
+ st.session_state.text_prompt = None
170
+
171
+ if st.session_state.option is None:
172
+ st.write('please select demo case first')
173
+ else:
174
+ image_3D = st.session_state.data_item['z_image'][0].numpy()
175
+ col_control1, col_control2 = st.columns(2)
176
+
177
+ with col_control1:
178
+ selected_index_z = st.slider('X-Y view', 0, image_3D.shape[0] - 1, 162, key='xy', disabled=st.session_state.running)
179
+
180
+ with col_control2:
181
+ selected_index_y = st.slider('X-Z view', 0, image_3D.shape[1] - 1, 162, key='xz', disabled=st.session_state.running)
182
+ if st.session_state.use_box_prompt:
183
+ top, bottom = st.select_slider(
184
+ 'Top and bottom of box',
185
+ options=range(0, 325),
186
+ value=(0, 324),
187
+ disabled=st.session_state.running
188
+ )
189
+ st.session_state.rectangle_3Dbox[0] = top
190
+ st.session_state.rectangle_3Dbox[3] = bottom
191
+ col_image1, col_image2 = st.columns(2)
192
+
193
+ if st.session_state.preds_3D is not None:
194
+ st.session_state.transparency = st.slider('Mask opacity', 0.0, 1.0, 0.25, disabled=st.session_state.running)
195
+
196
+ with col_image1:
197
+
198
+ image_z_array = image_3D[selected_index_z]
199
+
200
+ preds_z_array = None
201
+ if st.session_state.preds_3D is not None:
202
+ preds_z_array = st.session_state.preds_3D[selected_index_z]
203
+
204
+ image_z = make_fig(image_z_array, preds_z_array, st.session_state.points, selected_index_z, 'xy')
205
+
206
+
207
+ if st.session_state.use_point_prompt:
208
+ value_xy = streamlit_image_coordinates(image_z, width=325)
209
+
210
+ if value_xy is not None:
211
+ point_ax_xy = (selected_index_z, value_xy['y'], value_xy['x'])
212
+ if len(st.session_state.points) >= 3:
213
+ st.warning('Max point num is 3', icon="⚠️")
214
+ elif point_ax_xy not in st.session_state.points:
215
+ st.session_state.points.append(point_ax_xy)
216
+ print('point_ax_xy add rerun')
217
+ st.rerun()
218
+ elif st.session_state.use_box_prompt:
219
+ canvas_result_xy = st_canvas(
220
+ fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
221
+ stroke_width=3,
222
+ stroke_color='#2909F1',
223
+ background_image=image_z,
224
+ update_streamlit=True,
225
+ height=325,
226
+ width=325,
227
+ drawing_mode='transform',
228
+ point_display_radius=0,
229
+ key="canvas_xy",
230
+ initial_drawing=initial_rectangle,
231
+ display_toolbar=True
232
+ )
233
+ try:
234
+ print(canvas_result_xy.json_data['objects'][0]['angle'])
235
+ if canvas_result_xy.json_data['objects'][0]['angle'] != 0:
236
+ st.warning('Rotating is undefined behavior', icon="⚠️")
237
+ st.session_state.irregular_box = True
238
+ else:
239
+ st.session_state.irregular_box = False
240
+ reflect_json_data_to_3D_box(canvas_result_xy.json_data, view='xy')
241
+ except:
242
+ print('exception')
243
+ pass
244
+ else:
245
+ st.image(image_z, use_column_width=False)
246
+
247
+ with col_image2:
248
+ image_y_array = image_3D[:, selected_index_y, :]
249
+
250
+ preds_y_array = None
251
+ if st.session_state.preds_3D is not None:
252
+ preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
253
+
254
+ image_y = make_fig(image_y_array, preds_y_array, st.session_state.points, selected_index_y, 'xz')
255
+
256
+ if st.session_state.use_point_prompt:
257
+ value_yz = streamlit_image_coordinates(image_y, width=325)
258
+
259
+ if value_yz is not None:
260
+ point_ax_xz = (value_yz['y'], selected_index_y, value_yz['x'])
261
+ if len(st.session_state.points) >= 3:
262
+ st.warning('Max point num is 3', icon="⚠️")
263
+ elif point_ax_xz not in st.session_state.points:
264
+ st.session_state.points.append(point_ax_xz)
265
+ print('point_ax_xz add rerun')
266
+ st.rerun()
267
+ elif st.session_state.use_box_prompt:
268
+ if st.session_state.rectangle_3Dbox[1] <= selected_index_y and selected_index_y <= st.session_state.rectangle_3Dbox[4]:
269
+ draw = ImageDraw.Draw(image_y)
270
+ #rectangle xz view (upper-left and lower-right)
271
+ rectangle_coords = [(st.session_state.rectangle_3Dbox[2], st.session_state.rectangle_3Dbox[0]),
272
+ (st.session_state.rectangle_3Dbox[5], st.session_state.rectangle_3Dbox[3])]
273
+ # Draw the rectangle on the image
274
+ draw.rectangle(rectangle_coords, outline='#2909F1', width=3)
275
+ st.image(image_y, use_column_width=False)
276
+ else:
277
+ st.image(image_y, use_column_width=False)
278
+
279
+
280
+ col1, col2, col3 = st.columns(3)
281
+
282
+ with col1:
283
+ if st.button("Clear", use_container_width=True,
284
+ disabled=(st.session_state.option is None or (len(st.session_state.points)==0 and not st.session_state.use_box_prompt and st.session_state.preds_3D is None))):
285
+ clear_prompts()
286
+ st.session_state.preds_3D = None
287
+ st.rerun()
288
+
289
+ with col3:
290
+ run_button_name = 'Run'if not st.session_state.running else 'Running'
291
+ if st.button(run_button_name, type="primary", use_container_width=True,
292
+ disabled=(
293
+ st.session_state.data_item is None or
294
+ (st.session_state.text_prompt is None and len(st.session_state.points) == 0 and st.session_state.use_box_prompt is False) or
295
+ st.session_state.irregular_box or
296
+ st.session_state.running
297
+ )):
298
+ st.session_state.running = True
299
+ st.rerun()
300
+
301
+ # if len(st.session_state.points) > 0:
302
+ # st.write(st.session_state.points)
303
+
304
+ if st.session_state.running:
305
+ st.session_state.running = False
306
+ with st.status("Running...", expanded=False) as status:
307
+ run()
308
+ st.rerun()
model/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 BAAI-DCAI
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
model/README.md ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SegVol: Universal and Interactive Volumetric Medical Image Segmentation
2
+ This repo is the official implementation of [SegVol: Universal and Interactive Volumetric Medical Image Segmentation](https://arxiv.org/abs/2311.13385).
3
+
4
+ ## News🚀
5
+ (2023.11.24) *You can download weight files of SegVol and ViT(CTs pre-train) [here](https://drive.google.com/drive/folders/1TEJtgctH534Ko5r4i79usJvqmXVuLf54?usp=drive_link).* 🔥
6
+
7
+ (2023.11.23) *The brief introduction and instruction have been uploaded.*
8
+
9
+ (2023.11.23) *The inference demo code has been uploaded.*
10
+
11
+ (2023.11.22) *The first edition of our paper has been uploaded to arXiv.* 📃
12
+
13
+ ## Introduction
14
+ <img src="https://github.com/BAAI-DCAI/SegVol/blob/main/asset/overview.png" width="60%" height="60%">
15
+
16
+ The SegVol is a universal and interactive model for volumetric medical image segmentation. SegVol accepts **point**, **box** and **text** prompt while output volumetric segmentation. By training on 90k unlabeled Computed Tomography (CT) volumes and 6k labeled CTs, this foundation model supports the segmentation of over 200 anatomical categories.
17
+
18
+ We will release SegVol's **inference code**, **training code**, **model params** and **ViT pre-training params** (pre-training is performed over 2,000 epochs on 96k CTs).
19
+
20
+ ## Usage
21
+ ### Requirements
22
+ The [pytorch v1.11.0](https://pytorch.org/get-started/previous-versions/) (or higher virsion) is needed first. Following install key requirements using commands:
23
+
24
+ ```
25
+ pip install 'monai[all]==0.9.0'
26
+ pip install einops==0.6.1
27
+ pip install transformers==4.18.0
28
+ pip install matplotlib
29
+ ```
30
+ ### Config and run demo script
31
+ 1. You can download the demo case [here](https://drive.google.com/drive/folders/1TEJtgctH534Ko5r4i79usJvqmXVuLf54?usp=drive_link), or download the whole demo dataset [AbdomenCT-1K](https://github.com/JunMa11/AbdomenCT-1K) and choose any demo case you want.
32
+ 2. Please set CT path and Ground Truth path of the case in the [config_demo.json](https://github.com/BAAI-DCAI/SegVol/blob/main/config/config_demo.json).
33
+ 3. After that, config the [inference_demo.sh](https://github.com/BAAI-DCAI/SegVol/blob/main/script/inference_demo.sh) file for execution:
34
+
35
+ - `$segvol_ckpt`: the path of SegVol's checkpoint (Download from [here](https://drive.google.com/drive/folders/1TEJtgctH534Ko5r4i79usJvqmXVuLf54?usp=drive_link)).
36
+
37
+ - `$work_dir`: any path of folder you want to save the log files and visualizaion results.
38
+
39
+ 4. Finally, you can control the **prompt type**, **zoom-in-zoom-out mechanism** and **visualizaion switch** [here](https://github.com/BAAI-DCAI/SegVol/blob/35f3ff9c943a74f630e6948051a1fe21aaba91bc/inference_demo.py#L208C11-L208C11).
40
+ 5. Now, just run `bash script/inference_demo.sh` to infer your demo case.
41
+
42
+ ## Citation
43
+ If you find this repository helpful, please consider citing:
44
+ ```
45
+ @misc{du2023segvol,
46
+ title={SegVol: Universal and Interactive Volumetric Medical Image Segmentation},
47
+ author={Yuxin Du and Fan Bai and Tiejun Huang and Bo Zhao},
48
+ year={2023},
49
+ eprint={2311.13385},
50
+ archivePrefix={arXiv},
51
+ primaryClass={cs.CV}
52
+ }
53
+ ```
54
+
55
+ ## Acknowledgement
56
+ Thanks for the following amazing works:
57
+
58
+ [HuggingFace](https://huggingface.co/).
59
+
60
+ [CLIP](https://github.com/openai/CLIP).
61
+
62
+ [MONAI](https://github.com/Project-MONAI/MONAI).
63
+
64
+ [Image by brgfx](https://www.freepik.com/free-vector/anatomical-structure-human-bodies_26353260.htm) on Freepik.
65
+
66
+ [Image by muammark](https://www.freepik.com/free-vector/people-icon-collection_1157380.htm#query=user&position=2&from_view=search&track=sph) on Freepik.
67
+
68
+ [Image by pch.vector](https://www.freepik.com/free-vector/different-phone-hand-gestures-set_9649376.htm#query=Vector%20touch%20screen%20hand%20gestures&position=4&from_view=search&track=ais) on Freepik.
69
+
70
+ [Image by starline](https://www.freepik.com/free-vector/set-three-light-bulb-represent-effective-business-idea-concept_37588597.htm#query=idea&position=0&from_view=search&track=sph) on Freepik.
71
+
72
+
73
+
74
+
model/__pycache__/inference_cpu.cpython-39.pyc ADDED
Binary file (4.67 kB). View file
 
model/asset/FLARE22_Tr_0002_0000.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb16eced003524fa005e28b2822c0b53503f1223d758cdf72528fad359aa10ba
3
+ size 30611274
model/asset/FLARE22_Tr_0005_0000.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2be5019bfc7e805d5e24785bcd44ffe7720e13e38b2a3124ad25b454811b221c
3
+ size 26615527
model/asset/FLARE22_Tr_0034_0000.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:023c5d06ea2a6c8866c1e214ecee06a4447a8d0c50225142cdfdbbccc2bf8c66
3
+ size 28821917
model/asset/FLARE22_Tr_0045_0000.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:336b3719af673fd6fafe89d7d5d95d5f18239a9faccde9753703fc1465f43736
3
+ size 32885093
model/asset/model.png ADDED
model/asset/overview back.png ADDED
model/asset/overview.png ADDED
model/config/clip/config.json ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "openai/clip-vit-base-patch32",
3
+ "architectures": [
4
+ "CLIPModel"
5
+ ],
6
+ "initializer_factor": 1.0,
7
+ "logit_scale_init_value": 2.6592,
8
+ "model_type": "clip",
9
+ "projection_dim": 512,
10
+ "text_config": {
11
+ "_name_or_path": "",
12
+ "add_cross_attention": false,
13
+ "architectures": null,
14
+ "attention_dropout": 0.0,
15
+ "bad_words_ids": null,
16
+ "bos_token_id": 0,
17
+ "chunk_size_feed_forward": 0,
18
+ "cross_attention_hidden_size": null,
19
+ "decoder_start_token_id": null,
20
+ "diversity_penalty": 0.0,
21
+ "do_sample": false,
22
+ "dropout": 0.0,
23
+ "early_stopping": false,
24
+ "encoder_no_repeat_ngram_size": 0,
25
+ "eos_token_id": 2,
26
+ "finetuning_task": null,
27
+ "forced_bos_token_id": null,
28
+ "forced_eos_token_id": null,
29
+ "hidden_act": "quick_gelu",
30
+ "hidden_size": 512,
31
+ "id2label": {
32
+ "0": "LABEL_0",
33
+ "1": "LABEL_1"
34
+ },
35
+ "initializer_factor": 1.0,
36
+ "initializer_range": 0.02,
37
+ "intermediate_size": 2048,
38
+ "is_decoder": false,
39
+ "is_encoder_decoder": false,
40
+ "label2id": {
41
+ "LABEL_0": 0,
42
+ "LABEL_1": 1
43
+ },
44
+ "layer_norm_eps": 1e-05,
45
+ "length_penalty": 1.0,
46
+ "max_length": 20,
47
+ "max_position_embeddings": 77,
48
+ "min_length": 0,
49
+ "model_type": "clip_text_model",
50
+ "no_repeat_ngram_size": 0,
51
+ "num_attention_heads": 8,
52
+ "num_beam_groups": 1,
53
+ "num_beams": 1,
54
+ "num_hidden_layers": 12,
55
+ "num_return_sequences": 1,
56
+ "output_attentions": false,
57
+ "output_hidden_states": false,
58
+ "output_scores": false,
59
+ "pad_token_id": 1,
60
+ "prefix": null,
61
+ "projection_dim": 512,
62
+ "problem_type": null,
63
+ "pruned_heads": {},
64
+ "remove_invalid_values": false,
65
+ "repetition_penalty": 1.0,
66
+ "return_dict": true,
67
+ "return_dict_in_generate": false,
68
+ "sep_token_id": null,
69
+ "task_specific_params": null,
70
+ "temperature": 1.0,
71
+ "tie_encoder_decoder": false,
72
+ "tie_word_embeddings": true,
73
+ "tokenizer_class": null,
74
+ "top_k": 50,
75
+ "top_p": 1.0,
76
+ "torch_dtype": null,
77
+ "torchscript": false,
78
+ "transformers_version": "4.16.0.dev0",
79
+ "use_bfloat16": false,
80
+ "vocab_size": 49408
81
+ },
82
+ "text_config_dict": null,
83
+ "transformers_version": null,
84
+ "vision_config": {
85
+ "_name_or_path": "",
86
+ "add_cross_attention": false,
87
+ "architectures": null,
88
+ "attention_dropout": 0.0,
89
+ "bad_words_ids": null,
90
+ "bos_token_id": null,
91
+ "chunk_size_feed_forward": 0,
92
+ "cross_attention_hidden_size": null,
93
+ "decoder_start_token_id": null,
94
+ "diversity_penalty": 0.0,
95
+ "do_sample": false,
96
+ "dropout": 0.0,
97
+ "early_stopping": false,
98
+ "encoder_no_repeat_ngram_size": 0,
99
+ "eos_token_id": null,
100
+ "finetuning_task": null,
101
+ "forced_bos_token_id": null,
102
+ "forced_eos_token_id": null,
103
+ "hidden_act": "quick_gelu",
104
+ "hidden_size": 768,
105
+ "id2label": {
106
+ "0": "LABEL_0",
107
+ "1": "LABEL_1"
108
+ },
109
+ "image_size": 224,
110
+ "initializer_factor": 1.0,
111
+ "initializer_range": 0.02,
112
+ "intermediate_size": 3072,
113
+ "is_decoder": false,
114
+ "is_encoder_decoder": false,
115
+ "label2id": {
116
+ "LABEL_0": 0,
117
+ "LABEL_1": 1
118
+ },
119
+ "layer_norm_eps": 1e-05,
120
+ "length_penalty": 1.0,
121
+ "max_length": 20,
122
+ "min_length": 0,
123
+ "model_type": "clip_vision_model",
124
+ "no_repeat_ngram_size": 0,
125
+ "num_attention_heads": 12,
126
+ "num_beam_groups": 1,
127
+ "num_beams": 1,
128
+ "num_hidden_layers": 12,
129
+ "num_return_sequences": 1,
130
+ "output_attentions": false,
131
+ "output_hidden_states": false,
132
+ "output_scores": false,
133
+ "pad_token_id": null,
134
+ "patch_size": 32,
135
+ "prefix": null,
136
+ "projection_dim" : 512,
137
+ "problem_type": null,
138
+ "pruned_heads": {},
139
+ "remove_invalid_values": false,
140
+ "repetition_penalty": 1.0,
141
+ "return_dict": true,
142
+ "return_dict_in_generate": false,
143
+ "sep_token_id": null,
144
+ "task_specific_params": null,
145
+ "temperature": 1.0,
146
+ "tie_encoder_decoder": false,
147
+ "tie_word_embeddings": true,
148
+ "tokenizer_class": null,
149
+ "top_k": 50,
150
+ "top_p": 1.0,
151
+ "torch_dtype": null,
152
+ "torchscript": false,
153
+ "transformers_version": "4.16.0.dev0",
154
+ "use_bfloat16": false
155
+ },
156
+ "vision_config_dict": null
157
+ }
model/config/clip/special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
model/config/clip/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
model/config/clip/tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": "<|endoftext|>", "add_prefix_space": false, "errors": "replace", "do_lower_case": true, "name_or_path": "./clip_ViT_B_32/"}
model/config/clip/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
model/config/config_demo.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_name": "AbdomenCT-1k",
3
+ "categories": ["liver", "kidney", "spleen", "pancreas"],
4
+ "demo_case": {
5
+ "ct_path": "path/to/Case_image",
6
+ "gt_path": "path/to/Case_label"
7
+ }
8
+ }
model/data_process/__pycache__/demo_data_process.cpython-39.pyc ADDED
Binary file (3.26 kB). View file
 
model/data_process/demo_data_process.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import monai.transforms as transforms
3
+ import streamlit as st
4
+ import tempfile
5
+
6
+ class MinMaxNormalization(transforms.Transform):
7
+ def __call__(self, data):
8
+ d = dict(data)
9
+ k = "image"
10
+ d[k] = d[k] - d[k].min()
11
+ d[k] = d[k] / np.clip(d[k].max(), a_min=1e-8, a_max=None)
12
+ return d
13
+
14
+ class DimTranspose(transforms.Transform):
15
+ def __init__(self, keys):
16
+ self.keys = keys
17
+
18
+ def __call__(self, data):
19
+ d = dict(data)
20
+ for key in self.keys:
21
+ d[key] = np.swapaxes(d[key], -1, -3)
22
+ return d
23
+
24
+ class ForegroundNormalization(transforms.Transform):
25
+ def __init__(self, keys):
26
+ self.keys = keys
27
+
28
+ def __call__(self, data):
29
+ d = dict(data)
30
+
31
+ for key in self.keys:
32
+ d[key] = self.normalize(d[key])
33
+ return d
34
+
35
+ def normalize(self, ct_narray):
36
+ ct_voxel_ndarray = ct_narray.copy()
37
+ ct_voxel_ndarray = ct_voxel_ndarray.flatten()
38
+ thred = np.mean(ct_voxel_ndarray)
39
+ voxel_filtered = ct_voxel_ndarray[(ct_voxel_ndarray > thred)]
40
+ upper_bound = np.percentile(voxel_filtered, 99.95)
41
+ lower_bound = np.percentile(voxel_filtered, 00.05)
42
+ mean = np.mean(voxel_filtered)
43
+ std = np.std(voxel_filtered)
44
+ ### transform ###
45
+ ct_narray = np.clip(ct_narray, lower_bound, upper_bound)
46
+ ct_narray = (ct_narray - mean) / max(std, 1e-8)
47
+ return ct_narray
48
+
49
+ @st.cache_data
50
+ def process_ct_gt(case_path, spatial_size=(32,256,256)):
51
+ if case_path is None:
52
+ return None
53
+ print('Data preprocessing...')
54
+ # transform
55
+ img_loader = transforms.LoadImage(dtype=np.float32)
56
+ transform = transforms.Compose(
57
+ [
58
+ transforms.Orientationd(keys=["image"], axcodes="RAS"),
59
+ ForegroundNormalization(keys=["image"]),
60
+ DimTranspose(keys=["image"]),
61
+ MinMaxNormalization(),
62
+ transforms.SpatialPadd(keys=["image"], spatial_size=spatial_size, mode='constant'),
63
+ transforms.CropForegroundd(keys=["image"], source_key="image"),
64
+ transforms.ToTensord(keys=["image"]),
65
+ ]
66
+ )
67
+ zoom_out_transform = transforms.Resized(keys=["image"], spatial_size=spatial_size, mode='nearest-exact')
68
+ z_transform = transforms.Resized(keys=["image"], spatial_size=(325,325,325), mode='nearest-exact')
69
+ ###
70
+ item = {}
71
+ # generate ct_voxel_ndarray
72
+ if type(case_path) is str:
73
+ ct_voxel_ndarray, _ = img_loader(case_path)
74
+ else:
75
+ bytes_data = case_path.read()
76
+ with tempfile.NamedTemporaryFile(suffix='.nii.gz') as tmp:
77
+ tmp.write(bytes_data)
78
+ tmp.seek(0)
79
+ ct_voxel_ndarray, _ = img_loader(tmp.name)
80
+ ct_voxel_ndarray = np.array(ct_voxel_ndarray).squeeze()
81
+ ct_voxel_ndarray = np.expand_dims(ct_voxel_ndarray, axis=0)
82
+ item['image'] = ct_voxel_ndarray
83
+
84
+ # transform
85
+ item = transform(item)
86
+ item_zoom_out = zoom_out_transform(item)
87
+ item['zoom_out_image'] = item_zoom_out['image']
88
+
89
+ item_z = z_transform(item)
90
+ item['z_image'] = item_z['image']
91
+ return item
model/inference_cpu.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import json
6
+ import monai.transforms as transforms
7
+
8
+ from model.segment_anything_volumetric import sam_model_registry
9
+ from model.network.model import SegVol
10
+ from model.data_process.demo_data_process import process_ct_gt
11
+ from model.utils.monai_inferers_utils import sliding_window_inference, generate_box, select_points, build_binary_cube, build_binary_points, logits2roi_coor
12
+ from model.utils.visualize import draw_result
13
+ import streamlit as st
14
+
15
+ def set_parse():
16
+ # %% set up parser
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("--test_mode", default=True, type=bool)
19
+ parser.add_argument("--resume", type = str, default = 'SegVol_v1.pth')
20
+ parser.add_argument("-infer_overlap", default=0.0, type=float, help="sliding window inference overlap")
21
+ parser.add_argument("-spatial_size", default=(32, 256, 256), type=tuple)
22
+ parser.add_argument("-patch_size", default=(4, 16, 16), type=tuple)
23
+ parser.add_argument('-work_dir', type=str, default='./work_dir')
24
+ ### demo
25
+ parser.add_argument("--clip_ckpt", type = str, default = 'model/config/clip')
26
+ args = parser.parse_args()
27
+ return args
28
+
29
+ def zoom_in_zoom_out(args, segvol_model, image, image_resize, text_prompt, point_prompt, box_prompt):
30
+ image_single_resize = image_resize
31
+ image_single = image[0,0]
32
+ ori_shape = image_single.shape
33
+ resize_shape = image_single_resize.shape[2:]
34
+
35
+ # generate prompts
36
+ text_single = None if text_prompt is None else [text_prompt]
37
+ points_single = None
38
+ box_single = None
39
+
40
+ if args.use_point_prompt:
41
+ point, point_label = point_prompt
42
+ points_single = (point.unsqueeze(0).float(), point_label.unsqueeze(0).float())
43
+ binary_points_resize = build_binary_points(point, point_label, resize_shape)
44
+ if args.use_box_prompt:
45
+ box_single = box_prompt.unsqueeze(0).float()
46
+ binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=resize_shape)
47
+
48
+ ####################
49
+ # zoom-out inference:
50
+ print('--- zoom out inference ---')
51
+ print(text_single)
52
+ print(f'use text-prompt [{text_single!=None}], use box-prompt [{box_single!=None}], use point-prompt [{points_single!=None}]')
53
+ with torch.no_grad():
54
+ logits_global_single = segvol_model(image_single_resize,
55
+ text=text_single,
56
+ boxes=box_single,
57
+ points=points_single)
58
+
59
+ # resize back global logits
60
+ logits_global_single = F.interpolate(
61
+ logits_global_single.cpu(),
62
+ size=ori_shape, mode='nearest')[0][0]
63
+
64
+ # build prompt reflection for zoom-in
65
+ if args.use_point_prompt:
66
+ binary_points = F.interpolate(
67
+ binary_points_resize.unsqueeze(0).unsqueeze(0).float(),
68
+ size=ori_shape, mode='nearest')[0][0]
69
+ if args.use_box_prompt:
70
+ binary_cube = F.interpolate(
71
+ binary_cube_resize.unsqueeze(0).unsqueeze(0).float(),
72
+ size=ori_shape, mode='nearest')[0][0]
73
+ # draw_result('unknow', image_single_resize, None, point_prompt, logits_global_single, logits_global_single)
74
+ if not args.use_zoom_in:
75
+ return logits_global_single
76
+
77
+ ####################
78
+ # zoom-in inference:
79
+ min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(args.spatial_size, logits_global_single)
80
+ if min_d is None:
81
+ print('Fail to detect foreground!')
82
+ return logits_global_single
83
+
84
+ # Crop roi
85
+ image_single_cropped = image_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
86
+ global_preds = (torch.sigmoid(logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()
87
+
88
+ assert not (args.use_box_prompt and args.use_point_prompt)
89
+ # label_single_cropped = label_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
90
+ prompt_reflection = None
91
+ if args.use_box_prompt:
92
+ binary_cube_cropped = binary_cube[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
93
+ prompt_reflection = (
94
+ binary_cube_cropped.unsqueeze(0).unsqueeze(0),
95
+ global_preds.unsqueeze(0).unsqueeze(0)
96
+ )
97
+ if args.use_point_prompt:
98
+ binary_points_cropped = binary_points[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
99
+ prompt_reflection = (
100
+ binary_points_cropped.unsqueeze(0).unsqueeze(0),
101
+ global_preds.unsqueeze(0).unsqueeze(0)
102
+ )
103
+
104
+ ## inference
105
+ with torch.no_grad():
106
+ logits_single_cropped = sliding_window_inference(
107
+ image_single_cropped, prompt_reflection,
108
+ args.spatial_size, 1, segvol_model, args.infer_overlap,
109
+ text=text_single,
110
+ use_box=args.use_box_prompt,
111
+ use_point=args.use_point_prompt,
112
+ logits_global_single=logits_global_single,
113
+ )
114
+ logits_single_cropped = logits_single_cropped.cpu().squeeze()
115
+ if logits_single_cropped.shape != logits_global_single.shape:
116
+ logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped
117
+
118
+ return logits_global_single
119
+
120
+ @st.cache_resource
121
+ def build_model():
122
+ # build model
123
+ st.write('building model')
124
+ clip_ckpt = 'model/config/clip'
125
+ resume = 'SegVol_v1.pth'
126
+ sam_model = sam_model_registry['vit']()
127
+ segvol_model = SegVol(
128
+ image_encoder=sam_model.image_encoder,
129
+ mask_decoder=sam_model.mask_decoder,
130
+ prompt_encoder=sam_model.prompt_encoder,
131
+ clip_ckpt=clip_ckpt,
132
+ roi_size=(32,256,256),
133
+ patch_size=(4,16,16),
134
+ test_mode=True,
135
+ )
136
+ segvol_model = torch.nn.DataParallel(segvol_model)
137
+ segvol_model.eval()
138
+ # load param
139
+ if os.path.isfile(resume):
140
+ ## Map model to be loaded to specified single GPU
141
+ loc = 'cpu'
142
+ checkpoint = torch.load(resume, map_location=loc)
143
+ segvol_model.load_state_dict(checkpoint['model'], strict=True)
144
+ print("loaded checkpoint '{}' (epoch {})".format(resume, checkpoint['epoch']))
145
+ print('model build done!')
146
+ return segvol_model
147
+
148
+ @st.cache_data
149
+ def inference_case(_image, _image_zoom_out, _point_prompt, text_prompt, _box_prompt):
150
+ # seg config
151
+ args = set_parse()
152
+ args.use_zoom_in = True
153
+ args.use_text_prompt = text_prompt is not None
154
+ args.use_box_prompt = _box_prompt is not None
155
+ args.use_point_prompt = _point_prompt is not None
156
+
157
+ segvol_model = build_model()
158
+
159
+ # run inference
160
+ logits = zoom_in_zoom_out(
161
+ args, segvol_model,
162
+ _image.unsqueeze(0), _image_zoom_out.unsqueeze(0),
163
+ text_prompt, _point_prompt, _box_prompt)
164
+ print(logits.shape)
165
+ resize_transform = transforms.Compose([
166
+ transforms.AddChannel(),
167
+ transforms.Resize((325,325,325), mode='trilinear')
168
+ ]
169
+ )
170
+ logits = resize_transform(logits)[0]
171
+ print(logits.shape)
172
+ return (torch.sigmoid(logits) > 0.5).int().numpy()
173
+
model/inference_demo.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import json
6
+ from segment_anything_volumetric import sam_model_registry
7
+ from network.model import SegVol
8
+ from data_process.demo_data_process import process_ct_gt
9
+ import monai.transforms as transforms
10
+ from utils.monai_inferers_utils import sliding_window_inference, generate_box, select_points, build_binary_cube, build_binary_points, logits2roi_coor
11
+ from utils.visualize import draw_result
12
+
13
+ def set_parse():
14
+ # %% set up parser
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument("--test_mode", default=True, type=bool)
17
+ parser.add_argument("--resume", type = str, default = '')
18
+ parser.add_argument("-infer_overlap", default=0.5, type=float, help="sliding window inference overlap")
19
+ parser.add_argument("-spatial_size", default=(32, 256, 256), type=tuple)
20
+ parser.add_argument("-patch_size", default=(4, 16, 16), type=tuple)
21
+ parser.add_argument('-work_dir', type=str, default='./work_dir')
22
+ ### demo
23
+ parser.add_argument('--demo_config', type=str, required=True)
24
+ parser.add_argument("--clip_ckpt", type = str, default = './config/clip')
25
+ args = parser.parse_args()
26
+ return args
27
+
28
+ def dice_score(preds, labels): # on GPU
29
+ assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match\n" + str(preds.shape) + str(labels.shape)
30
+ predict = preds.view(1, -1)
31
+ target = labels.view(1, -1)
32
+ if target.shape[1] < 1e8:
33
+ predict = predict.cuda()
34
+ target = target.cuda()
35
+ predict = torch.sigmoid(predict)
36
+ predict = torch.where(predict > 0.5, 1., 0.)
37
+
38
+ tp = torch.sum(torch.mul(predict, target))
39
+ den = torch.sum(predict) + torch.sum(target) + 1
40
+ dice = 2 * tp / den
41
+
42
+ if target.shape[1] < 1e8:
43
+ predict = predict.cpu()
44
+ target = target.cpu()
45
+ return dice
46
+
47
+ def zoom_in_zoom_out(args, segvol_model, image, image_resize, gt3D, gt3D_resize, categories=None):
48
+ logits_labels_record = {}
49
+ image_single_resize = image_resize
50
+ image_single = image[0,0]
51
+ ori_shape = image_single.shape
52
+ for item_idx in range(len(categories)):
53
+ # get label to generate prompts
54
+ label_single = gt3D[0][item_idx]
55
+ label_single_resize = gt3D_resize[0][item_idx]
56
+ # skip meaningless categories
57
+ if torch.sum(label_single) == 0:
58
+ print('No object, skip')
59
+ continue
60
+ # generate prompts
61
+ text_single = categories[item_idx] if args.use_text_prompt else None
62
+ if categories is not None: print(f'inference |{categories[item_idx]}| target...')
63
+ points_single = None
64
+ box_single = None
65
+ if args.use_point_prompt:
66
+ point, point_label = select_points(label_single_resize, num_positive_extra=3, num_negative_extra=3)
67
+ points_single = (point.unsqueeze(0).float().cuda(), point_label.unsqueeze(0).float().cuda())
68
+ binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape)
69
+ if args.use_box_prompt:
70
+ box_single = generate_box(label_single_resize).unsqueeze(0).float().cuda()
71
+ binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape)
72
+
73
+ ####################
74
+ # zoom-out inference:
75
+ print('--- zoom out inference ---')
76
+ print(f'use text-prompt [{text_single!=None}], use box-prompt [{box_single!=None}], use point-prompt [{points_single!=None}]')
77
+ with torch.no_grad():
78
+ logits_global_single = segvol_model(image_single_resize.cuda(),
79
+ text=text_single,
80
+ boxes=box_single,
81
+ points=points_single)
82
+
83
+ # resize back global logits
84
+ logits_global_single = F.interpolate(
85
+ logits_global_single.cpu(),
86
+ size=ori_shape, mode='nearest')[0][0]
87
+
88
+ # build prompt reflection for zoom-in
89
+ if args.use_point_prompt:
90
+ binary_points = F.interpolate(
91
+ binary_points_resize.unsqueeze(0).unsqueeze(0).float(),
92
+ size=ori_shape, mode='nearest')[0][0]
93
+ if args.use_box_prompt:
94
+ binary_cube = F.interpolate(
95
+ binary_cube_resize.unsqueeze(0).unsqueeze(0).float(),
96
+ size=ori_shape, mode='nearest')[0][0]
97
+ zoom_out_dice = dice_score(logits_global_single.squeeze(), label_single.squeeze())
98
+ logits_labels_record[categories[item_idx]] = (
99
+ zoom_out_dice,
100
+ image_single,
101
+ points_single,
102
+ box_single,
103
+ logits_global_single,
104
+ label_single)
105
+ print(f'zoom out inference done with zoom_out_dice: {zoom_out_dice:.4f}')
106
+ if not args.use_zoom_in:
107
+ continue
108
+
109
+ ####################
110
+ # zoom-in inference:
111
+ min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(args.spatial_size, logits_global_single)
112
+ if min_d is None:
113
+ print('Fail to detect foreground!')
114
+ continue
115
+
116
+ # Crop roi
117
+ image_single_cropped = image_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
118
+ global_preds = (torch.sigmoid(logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()
119
+
120
+ assert not (args.use_box_prompt and args.use_point_prompt)
121
+ # label_single_cropped = label_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
122
+ prompt_reflection = None
123
+ if args.use_box_prompt:
124
+ binary_cube_cropped = binary_cube[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
125
+ prompt_reflection = (
126
+ binary_cube_cropped.unsqueeze(0).unsqueeze(0),
127
+ global_preds.unsqueeze(0).unsqueeze(0)
128
+ )
129
+ if args.use_point_prompt:
130
+ binary_points_cropped = binary_points[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
131
+ prompt_reflection = (
132
+ binary_points_cropped.unsqueeze(0).unsqueeze(0),
133
+ global_preds.unsqueeze(0).unsqueeze(0)
134
+ )
135
+
136
+ ## inference
137
+ with torch.no_grad():
138
+ logits_single_cropped = sliding_window_inference(
139
+ image_single_cropped.cuda(), prompt_reflection,
140
+ args.spatial_size, 1, segvol_model, args.infer_overlap,
141
+ text=text_single,
142
+ use_box=args.use_box_prompt,
143
+ use_point=args.use_point_prompt,
144
+ )
145
+ logits_single_cropped = logits_single_cropped.cpu().squeeze()
146
+ logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped
147
+ zoom_in_dice = dice_score(logits_global_single.squeeze(), label_single.squeeze())
148
+ logits_labels_record[categories[item_idx]] = (
149
+ zoom_in_dice,
150
+ image_single,
151
+ points_single,
152
+ box_single,
153
+ logits_global_single,
154
+ label_single)
155
+ print(f'===> zoom out dice {zoom_out_dice:.4f} -> zoom-out-zoom-in dice {zoom_in_dice:.4f} <===')
156
+ return logits_labels_record
157
+
158
+ def inference_single_ct(args, segvol_model, data_item, categories):
159
+ segvol_model.eval()
160
+ image, gt3D = data_item["image"].float(), data_item["label"]
161
+ image_zoom_out, gt3D__zoom_out = data_item["zoom_out_image"].float(), data_item['zoom_out_label']
162
+
163
+ logits_labels_record = zoom_in_zoom_out(
164
+ args, segvol_model,
165
+ image.unsqueeze(0), image_zoom_out.unsqueeze(0),
166
+ gt3D.unsqueeze(0), gt3D__zoom_out.unsqueeze(0), # add batch dim
167
+ categories=categories)
168
+
169
+ # visualize
170
+ if args.visualize:
171
+ for target, values in logits_labels_record.items():
172
+ dice_score, image, point_prompt, box_prompt, logits, labels = values
173
+ print(f'{target} result with Dice score {dice_score:.4f} visualizing')
174
+ draw_result(target + f"-Dice {dice_score:.4f}", image, box_prompt, point_prompt, logits, labels, args.spatial_size, args.work_dir)
175
+
176
+ def main(args):
177
+ gpu = 0
178
+ torch.cuda.set_device(gpu)
179
+ # build model
180
+ sam_model = sam_model_registry['vit'](args=args)
181
+ segvol_model = SegVol(
182
+ image_encoder=sam_model.image_encoder,
183
+ mask_decoder=sam_model.mask_decoder,
184
+ prompt_encoder=sam_model.prompt_encoder,
185
+ clip_ckpt=args.clip_ckpt,
186
+ roi_size=args.spatial_size,
187
+ patch_size=args.patch_size,
188
+ test_mode=args.test_mode,
189
+ ).cuda()
190
+ segvol_model = torch.nn.DataParallel(segvol_model, device_ids=[gpu])
191
+
192
+ # load param
193
+ if os.path.isfile(args.resume):
194
+ ## Map model to be loaded to specified single GPU
195
+ loc = 'cuda:{}'.format(gpu)
196
+ checkpoint = torch.load(args.resume, map_location=loc)
197
+ segvol_model.load_state_dict(checkpoint['model'], strict=True)
198
+ print("loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
199
+
200
+ # load demo config
201
+ with open(args.demo_config, 'r') as file:
202
+ config_dict = json.load(file)
203
+ ct_path, gt_path, categories = config_dict['demo_case']['ct_path'], config_dict['demo_case']['gt_path'], config_dict['categories']
204
+
205
+ # preprocess for data
206
+ data_item = process_ct_gt(ct_path, gt_path, categories, args.spatial_size) # keys: image, label
207
+
208
+ # seg config for prompt & zoom-in-zoom-out
209
+ args.use_zoom_in = True
210
+ args.use_text_prompt = True
211
+ args.use_box_prompt = True
212
+ args.use_point_prompt = False
213
+ args.visualize = False
214
+
215
+ inference_single_ct(args, segvol_model, data_item, categories)
216
+
217
+ if __name__ == "__main__":
218
+ args = set_parse()
219
+ main(args)
model/network/__pycache__/model.cpython-39.pyc ADDED
Binary file (3.28 kB). View file
 
model/network/model.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from transformers import AutoTokenizer, CLIPTextModel, CLIPTextConfig
6
+
7
+ #%% set up model
8
+ class SegVol(nn.Module):
9
+ def __init__(self,
10
+ image_encoder,
11
+ mask_decoder,
12
+ prompt_encoder,
13
+ clip_ckpt,
14
+ roi_size,
15
+ patch_size,
16
+ test_mode=False,
17
+ ):
18
+ super().__init__()
19
+ self.image_encoder = image_encoder
20
+ self.mask_decoder = mask_decoder
21
+ self.prompt_encoder = prompt_encoder
22
+ self.text_encoder = TextEncoder(clip_ckpt)
23
+ self.feat_shape = np.array(roi_size)/np.array(patch_size)
24
+ self.test_mode = test_mode
25
+
26
+ def forward(self, image, text=None, boxes=None, points=None, **kwargs):
27
+ bs = image.shape[0]
28
+ img_shape = (image.shape[2], image.shape[3], image.shape[4])
29
+ image_embedding, _ = self.image_encoder(image)
30
+ image_embedding = image_embedding.transpose(1, 2).view(bs, -1,
31
+ int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))
32
+ # test mode
33
+ if self.test_mode:
34
+ return self.forward_decoder(image_embedding, img_shape, text, boxes, points)
35
+ # train mode
36
+ # future release
37
+
38
+ def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
39
+ with torch.no_grad():
40
+ if boxes is not None:
41
+ if len(boxes.shape) == 2:
42
+ boxes = boxes[:, None, :] # (B, 1, 6)
43
+ if text is not None:
44
+ text_embedding = self.text_encoder(text) # (B, 768)
45
+ else:
46
+ text_embedding = None
47
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
48
+ points=points,
49
+ boxes=boxes,
50
+ masks=None,
51
+ text_embedding=text_embedding,
52
+ )
53
+
54
+ dense_pe = self.prompt_encoder.get_dense_pe()
55
+ low_res_masks, _ = self.mask_decoder(
56
+ image_embeddings=image_embedding,
57
+ text_embedding = text_embedding,
58
+ image_pe=dense_pe,
59
+ sparse_prompt_embeddings=sparse_embeddings,
60
+ dense_prompt_embeddings=dense_embeddings,
61
+ multimask_output=False,
62
+ )
63
+ logits = F.interpolate(low_res_masks, size=img_shape, mode='trilinear', align_corners=False)
64
+ return logits
65
+
66
+ class TextEncoder(nn.Module):
67
+ def __init__(self, clip_ckpt):
68
+ super().__init__()
69
+ config = CLIPTextConfig()
70
+ self.clip_text_model = CLIPTextModel(config)
71
+ self.tokenizer = AutoTokenizer.from_pretrained(clip_ckpt)
72
+ self.dim_align = nn.Linear(512, 768)
73
+ # freeze text encoder
74
+ for param in self.clip_text_model.parameters():
75
+ param.requires_grad = False
76
+
77
+ def organ2tokens(self, organ_names):
78
+ text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
79
+ tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
80
+ return tokens
81
+
82
+ def forward(self, text):
83
+ if text is None:
84
+ return None
85
+ if type(text) is str:
86
+ text = [text]
87
+ tokens = self.organ2tokens(text)
88
+ clip_outputs = self.clip_text_model(**tokens)
89
+ text_embedding = clip_outputs.pooler_output
90
+ text_embedding = self.dim_align(text_embedding)
91
+ return text_embedding
model/script/inference_demo.sh ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ export segvol_ckpt="path/to/SegVol_v1.pth"
2
+ export work_dir="path/to/work_dir"
3
+ export demo_config_path="./config/config_demo.json"
4
+
5
+ CUDA_VISIBLE_DEVICES=0 python inference_demo.py \
6
+ --resume $segvol_ckpt \
7
+ -work_dir $work_dir \
8
+ --demo_config $demo_config_path
model/segment_anything_volumetric/.ipynb_checkpoints/build_sam-checkpoint.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ from functools import partial
7
+ from pathlib import Path
8
+ import urllib.request
9
+ import torch
10
+
11
+ from .modeling import (
12
+ ImageEncoderViT,
13
+ MaskDecoder,
14
+ PromptEncoder,
15
+ Sam,
16
+ TwoWayTransformer,
17
+ )
18
+
19
+ from .modeling.image_encoder_swin import SwinTransformer
20
+
21
+ from monai.utils import ensure_tuple_rep, optional_import
22
+
23
+ def build_sam_vit_h(checkpoint=None, image_size=1024):
24
+ return _build_sam(
25
+ encoder_embed_dim=1280,
26
+ encoder_depth=32,
27
+ encoder_num_heads=16,
28
+ encoder_global_attn_indexes=[7, 15, 23, 31],
29
+ checkpoint=checkpoint,
30
+ image_size=image_size,
31
+ )
32
+
33
+
34
+ build_sam = build_sam_vit_h
35
+
36
+
37
+ def build_sam_vit_l(checkpoint=None, image_size=1024):
38
+ return _build_sam(
39
+ encoder_embed_dim=1024,
40
+ encoder_depth=24,
41
+ encoder_num_heads=16,
42
+ encoder_global_attn_indexes=[5, 11, 17, 23],
43
+ checkpoint=checkpoint,
44
+ image_size=image_size,
45
+ )
46
+
47
+
48
+ def build_sam_vit_b(checkpoint=None, image_size=1024):
49
+ return _build_sam(
50
+ encoder_embed_dim=768,
51
+ encoder_depth=12,
52
+ encoder_num_heads=12,
53
+ encoder_global_attn_indexes=[2, 5, 8, 11],
54
+ checkpoint=checkpoint,
55
+ image_size=image_size,
56
+ )
57
+ """
58
+ Examples::
59
+ # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
60
+ >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
61
+ # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
62
+ >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
63
+ # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
64
+ >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
65
+ """
66
+
67
+ def build_sam_vit_swin(checkpoint=None, image_size=96):
68
+ print('==> build_sam_vit_swin')
69
+ return _build_sam(
70
+ encoder_embed_dim=48,
71
+ encoder_depth=12,
72
+ encoder_num_heads=12,
73
+ encoder_global_attn_indexes=[2, 5, 8, 11],
74
+ checkpoint=checkpoint,
75
+ image_size=image_size,
76
+ )
77
+
78
+ sam_model_registry = {
79
+ "default": build_sam_vit_h,
80
+ "vit_h": build_sam_vit_h,
81
+ "vit_l": build_sam_vit_l,
82
+ "vit_b": build_sam_vit_b,
83
+ "swin_vit": build_sam_vit_swin,
84
+ }
85
+
86
+
87
+ def _build_sam(
88
+ encoder_embed_dim,
89
+ encoder_depth,
90
+ encoder_num_heads,
91
+ encoder_global_attn_indexes,
92
+ checkpoint=None,
93
+ image_size=None,
94
+ spatial_dims=3,
95
+ ):
96
+ prompt_embed_dim = 768
97
+ patch_size = ensure_tuple_rep(2, spatial_dims)
98
+ window_size = ensure_tuple_rep(7, spatial_dims)
99
+ image_embedding_size = [size // 32 for size in image_size]
100
+ sam = Sam(
101
+ image_encoder=SwinTransformer(
102
+ in_chans=1,
103
+ embed_dim=encoder_embed_dim,
104
+ window_size=window_size,
105
+ patch_size=patch_size,
106
+ depths=(2, 2, 6, 2), #(2, 2, 6, 2),
107
+ num_heads=(3, 6, 12, 24),
108
+ mlp_ratio=4.0,
109
+ qkv_bias=True,
110
+ spatial_dims=spatial_dims,
111
+ ),
112
+ prompt_encoder=PromptEncoder(
113
+ embed_dim=prompt_embed_dim,
114
+ image_embedding_size=image_embedding_size,
115
+ input_image_size=image_size,
116
+ mask_in_chans=16,
117
+ ),
118
+ mask_decoder=MaskDecoder(
119
+ num_multimask_outputs=3,
120
+ transformer=TwoWayTransformer(
121
+ depth=2,
122
+ embedding_dim=prompt_embed_dim,
123
+ mlp_dim=2048,
124
+ num_heads=8,
125
+ ),
126
+ transformer_dim=prompt_embed_dim,
127
+ iou_head_depth=3,
128
+ iou_head_hidden_dim=256,
129
+ ),
130
+ pixel_mean=[123.675, 116.28, 103.53],
131
+ pixel_std=[58.395, 57.12, 57.375],
132
+ )
133
+ sam.eval()
134
+ if checkpoint is not None:
135
+ checkpoint = Path(checkpoint)
136
+ if checkpoint.name == "sam_vit_b_01ec64.pth" and not checkpoint.exists():
137
+ cmd = input("Download sam_vit_b_01ec64.pth from facebook AI? [y]/n: ")
138
+ if len(cmd) == 0 or cmd.lower() == 'y':
139
+ checkpoint.parent.mkdir(parents=True, exist_ok=True)
140
+ print("Downloading SAM ViT-B checkpoint...")
141
+ urllib.request.urlretrieve(
142
+ "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
143
+ checkpoint,
144
+ )
145
+ print(checkpoint.name, " is downloaded!")
146
+ elif checkpoint.name == "sam_vit_h_4b8939.pth" and not checkpoint.exists():
147
+ cmd = input("Download sam_vit_h_4b8939.pth from facebook AI? [y]/n: ")
148
+ if len(cmd) == 0 or cmd.lower() == 'y':
149
+ checkpoint.parent.mkdir(parents=True, exist_ok=True)
150
+ print("Downloading SAM ViT-H checkpoint...")
151
+ urllib.request.urlretrieve(
152
+ "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
153
+ checkpoint,
154
+ )
155
+ print(checkpoint.name, " is downloaded!")
156
+ elif checkpoint.name == "sam_vit_l_0b3195.pth" and not checkpoint.exists():
157
+ cmd = input("Download sam_vit_l_0b3195.pth from facebook AI? [y]/n: ")
158
+ if len(cmd) == 0 or cmd.lower() == 'y':
159
+ checkpoint.parent.mkdir(parents=True, exist_ok=True)
160
+ print("Downloading SAM ViT-L checkpoint...")
161
+ urllib.request.urlretrieve(
162
+ "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
163
+ checkpoint,
164
+ )
165
+ print(checkpoint.name, " is downloaded!")
166
+
167
+
168
+ if checkpoint is not None:
169
+ with open(checkpoint, "rb") as f:
170
+ state_dict = torch.load(f)
171
+ sam.load_state_dict(state_dict)
172
+ return sam
model/segment_anything_volumetric/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .build_sam import (
8
+ build_sam_vit_3d,
9
+ sam_model_registry,
10
+ )
11
+ from .predictor import SamPredictor
12
+ from .automatic_mask_generator import SamAutomaticMaskGenerator
model/segment_anything_volumetric/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (407 Bytes). View file
 
model/segment_anything_volumetric/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (377 Bytes). View file
 
model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-310.pyc ADDED
Binary file (11.4 kB). View file
 
model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-39.pyc ADDED
Binary file (11.4 kB). View file
 
model/segment_anything_volumetric/__pycache__/build_sam.cpython-310.pyc ADDED
Binary file (3.3 kB). View file
 
model/segment_anything_volumetric/__pycache__/build_sam.cpython-39.pyc ADDED
Binary file (2.62 kB). View file
 
model/segment_anything_volumetric/__pycache__/predictor.cpython-310.pyc ADDED
Binary file (9.96 kB). View file
 
model/segment_anything_volumetric/__pycache__/predictor.cpython-39.pyc ADDED
Binary file (9.98 kB). View file
 
model/segment_anything_volumetric/automatic_mask_generator.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
10
+
11
+ from typing import Any, Dict, List, Optional, Tuple
12
+
13
+ from .modeling import Sam
14
+ from .predictor import SamPredictor
15
+ from .utils.amg import (
16
+ MaskData,
17
+ area_from_rle,
18
+ batch_iterator,
19
+ batched_mask_to_box,
20
+ box_xyxy_to_xywh,
21
+ build_all_layer_point_grids,
22
+ calculate_stability_score,
23
+ coco_encode_rle,
24
+ generate_crop_boxes,
25
+ is_box_near_crop_edge,
26
+ mask_to_rle_pytorch,
27
+ remove_small_regions,
28
+ rle_to_mask,
29
+ uncrop_boxes_xyxy,
30
+ uncrop_masks,
31
+ uncrop_points,
32
+ )
33
+
34
+
35
+ class SamAutomaticMaskGenerator:
36
+ def __init__(
37
+ self,
38
+ model: Sam,
39
+ points_per_side: Optional[int] = 32,
40
+ points_per_batch: int = 64,
41
+ pred_iou_thresh: float = 0.88,
42
+ stability_score_thresh: float = 0.95,
43
+ stability_score_offset: float = 1.0,
44
+ box_nms_thresh: float = 0.7,
45
+ crop_n_layers: int = 0,
46
+ crop_nms_thresh: float = 0.7,
47
+ crop_overlap_ratio: float = 512 / 1500,
48
+ crop_n_points_downscale_factor: int = 1,
49
+ point_grids: Optional[List[np.ndarray]] = None,
50
+ min_mask_region_area: int = 0,
51
+ output_mode: str = "binary_mask",
52
+ ) -> None:
53
+ """
54
+ Using a SAM model, generates masks for the entire image.
55
+ Generates a grid of point prompts over the image, then filters
56
+ low quality and duplicate masks. The default settings are chosen
57
+ for SAM with a ViT-H backbone.
58
+
59
+ Arguments:
60
+ model (Sam): The SAM model to use for mask prediction.
61
+ points_per_side (int or None): The number of points to be sampled
62
+ along one side of the image. The total number of points is
63
+ points_per_side**2. If None, 'point_grids' must provide explicit
64
+ point sampling.
65
+ points_per_batch (int): Sets the number of points run simultaneously
66
+ by the model. Higher numbers may be faster but use more GPU memory.
67
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
68
+ model's predicted mask quality.
69
+ stability_score_thresh (float): A filtering threshold in [0,1], using
70
+ the stability of the mask under changes to the cutoff used to binarize
71
+ the model's mask predictions.
72
+ stability_score_offset (float): The amount to shift the cutoff when
73
+ calculated the stability score.
74
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
75
+ suppression to filter duplicate masks.
76
+ crop_n_layers (int): If >0, mask prediction will be run again on
77
+ crops of the image. Sets the number of layers to run, where each
78
+ layer has 2**i_layer number of image crops.
79
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
80
+ suppression to filter duplicate masks between different crops.
81
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
82
+ In the first crop layer, crops will overlap by this fraction of
83
+ the image length. Later layers with more crops scale down this overlap.
84
+ crop_n_points_downscale_factor (int): The number of points-per-side
85
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
86
+ point_grids (list(np.ndarray) or None): A list over explicit grids
87
+ of points used for sampling, normalized to [0,1]. The nth grid in the
88
+ list is used in the nth crop layer. Exclusive with points_per_side.
89
+ min_mask_region_area (int): If >0, postprocessing will be applied
90
+ to remove disconnected regions and holes in masks with area smaller
91
+ than min_mask_region_area. Requires opencv.
92
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
93
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
94
+ For large resolutions, 'binary_mask' may consume large amounts of
95
+ memory.
96
+ """
97
+
98
+ assert (points_per_side is None) != (
99
+ point_grids is None
100
+ ), "Exactly one of points_per_side or point_grid must be provided."
101
+ if points_per_side is not None:
102
+ self.point_grids = build_all_layer_point_grids(
103
+ points_per_side,
104
+ crop_n_layers,
105
+ crop_n_points_downscale_factor,
106
+ )
107
+ elif point_grids is not None:
108
+ self.point_grids = point_grids
109
+ else:
110
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
111
+
112
+ assert output_mode in [
113
+ "binary_mask",
114
+ "uncompressed_rle",
115
+ "coco_rle",
116
+ ], f"Unknown output_mode {output_mode}."
117
+ if output_mode == "coco_rle":
118
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
119
+
120
+ if min_mask_region_area > 0:
121
+ import cv2 # type: ignore # noqa: F401
122
+
123
+ self.predictor = SamPredictor(model)
124
+ self.points_per_batch = points_per_batch
125
+ self.pred_iou_thresh = pred_iou_thresh
126
+ self.stability_score_thresh = stability_score_thresh
127
+ self.stability_score_offset = stability_score_offset
128
+ self.box_nms_thresh = box_nms_thresh
129
+ self.crop_n_layers = crop_n_layers
130
+ self.crop_nms_thresh = crop_nms_thresh
131
+ self.crop_overlap_ratio = crop_overlap_ratio
132
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
133
+ self.min_mask_region_area = min_mask_region_area
134
+ self.output_mode = output_mode
135
+
136
+ @torch.no_grad()
137
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
138
+ """
139
+ Generates masks for the given image.
140
+
141
+ Arguments:
142
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
143
+
144
+ Returns:
145
+ list(dict(str, any)): A list over records for masks. Each record is
146
+ a dict containing the following keys:
147
+ segmentation (dict(str, any) or np.ndarray): The mask. If
148
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
149
+ is a dictionary containing the RLE.
150
+ bbox (list(float)): The box around the mask, in XYWH format.
151
+ area (int): The area in pixels of the mask.
152
+ predicted_iou (float): The model's own prediction of the mask's
153
+ quality. This is filtered by the pred_iou_thresh parameter.
154
+ point_coords (list(list(float))): The point coordinates input
155
+ to the model to generate this mask.
156
+ stability_score (float): A measure of the mask's quality. This
157
+ is filtered on using the stability_score_thresh parameter.
158
+ crop_box (list(float)): The crop of the image used to generate
159
+ the mask, given in XYWH format.
160
+ """
161
+
162
+ # Generate masks
163
+ mask_data = self._generate_masks(image)
164
+
165
+ # Filter small disconnected regions and holes in masks
166
+ if self.min_mask_region_area > 0:
167
+ mask_data = self.postprocess_small_regions(
168
+ mask_data,
169
+ self.min_mask_region_area,
170
+ max(self.box_nms_thresh, self.crop_nms_thresh),
171
+ )
172
+
173
+ # Encode masks
174
+ if self.output_mode == "coco_rle":
175
+ mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
176
+ elif self.output_mode == "binary_mask":
177
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
178
+ else:
179
+ mask_data["segmentations"] = mask_data["rles"]
180
+
181
+ # Write mask records
182
+ curr_anns = []
183
+ for idx in range(len(mask_data["segmentations"])):
184
+ ann = {
185
+ "segmentation": mask_data["segmentations"][idx],
186
+ "area": area_from_rle(mask_data["rles"][idx]),
187
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
188
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
189
+ "point_coords": [mask_data["points"][idx].tolist()],
190
+ "stability_score": mask_data["stability_score"][idx].item(),
191
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
192
+ }
193
+ curr_anns.append(ann)
194
+
195
+ return curr_anns
196
+
197
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
198
+ orig_size = image.shape[:2]
199
+ crop_boxes, layer_idxs = generate_crop_boxes(
200
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
201
+ )
202
+
203
+ # Iterate over image crops
204
+ data = MaskData()
205
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
206
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
207
+ data.cat(crop_data)
208
+
209
+ # Remove duplicate masks between crops
210
+ if len(crop_boxes) > 1:
211
+ # Prefer masks from smaller crops
212
+ scores = 1 / box_area(data["crop_boxes"])
213
+ scores = scores.to(data["boxes"].device)
214
+ keep_by_nms = batched_nms(
215
+ data["boxes"].float(),
216
+ scores,
217
+ torch.zeros_like(data["boxes"][:, 0]), # categories
218
+ iou_threshold=self.crop_nms_thresh,
219
+ )
220
+ data.filter(keep_by_nms)
221
+
222
+ data.to_numpy()
223
+ return data
224
+
225
+ def _process_crop(
226
+ self,
227
+ image: np.ndarray,
228
+ crop_box: List[int],
229
+ crop_layer_idx: int,
230
+ orig_size: Tuple[int, ...],
231
+ ) -> MaskData:
232
+ # Crop the image and calculate embeddings
233
+ x0, y0, x1, y1 = crop_box
234
+ cropped_im = image[y0:y1, x0:x1, :]
235
+ cropped_im_size = cropped_im.shape[:2]
236
+ self.predictor.set_image(cropped_im)
237
+
238
+ # Get points for this crop
239
+ points_scale = np.array(cropped_im_size)[None, ::-1]
240
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
241
+
242
+ # Generate masks for this crop in batches
243
+ data = MaskData()
244
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
245
+ batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
246
+ data.cat(batch_data)
247
+ del batch_data
248
+ self.predictor.reset_image()
249
+
250
+ # Remove duplicates within this crop.
251
+ keep_by_nms = batched_nms(
252
+ data["boxes"].float(),
253
+ data["iou_preds"],
254
+ torch.zeros_like(data["boxes"][:, 0]), # categories
255
+ iou_threshold=self.box_nms_thresh,
256
+ )
257
+ data.filter(keep_by_nms)
258
+
259
+ # Return to the original image frame
260
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
261
+ data["points"] = uncrop_points(data["points"], crop_box)
262
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
263
+
264
+ return data
265
+
266
+ def _process_batch(
267
+ self,
268
+ points: np.ndarray,
269
+ im_size: Tuple[int, ...],
270
+ crop_box: List[int],
271
+ orig_size: Tuple[int, ...],
272
+ ) -> MaskData:
273
+ orig_h, orig_w = orig_size
274
+
275
+ # Run model on this batch
276
+ transformed_points = self.predictor.transform.apply_coords(points, im_size)
277
+ in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
278
+ in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
279
+ masks, iou_preds, _ = self.predictor.predict_torch(
280
+ in_points[:, None, :],
281
+ in_labels[:, None],
282
+ multimask_output=True,
283
+ return_logits=True,
284
+ )
285
+
286
+ # Serialize predictions and store in MaskData
287
+ data = MaskData(
288
+ masks=masks.flatten(0, 1),
289
+ iou_preds=iou_preds.flatten(0, 1),
290
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
291
+ )
292
+ del masks
293
+
294
+ # Filter by predicted IoU
295
+ if self.pred_iou_thresh > 0.0:
296
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
297
+ data.filter(keep_mask)
298
+
299
+ # Calculate stability score
300
+ data["stability_score"] = calculate_stability_score(
301
+ data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
302
+ )
303
+ if self.stability_score_thresh > 0.0:
304
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
305
+ data.filter(keep_mask)
306
+
307
+ # Threshold masks and calculate boxes
308
+ data["masks"] = data["masks"] > self.predictor.model.mask_threshold
309
+ data["boxes"] = batched_mask_to_box(data["masks"])
310
+
311
+ # Filter boxes that touch crop boundaries
312
+ keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
313
+ if not torch.all(keep_mask):
314
+ data.filter(keep_mask)
315
+
316
+ # Compress to RLE
317
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
318
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
319
+ del data["masks"]
320
+
321
+ return data
322
+
323
+ @staticmethod
324
+ def postprocess_small_regions(
325
+ mask_data: MaskData, min_area: int, nms_thresh: float
326
+ ) -> MaskData:
327
+ """
328
+ Removes small disconnected regions and holes in masks, then reruns
329
+ box NMS to remove any new duplicates.
330
+
331
+ Edits mask_data in place.
332
+
333
+ Requires open-cv as a dependency.
334
+ """
335
+ if len(mask_data["rles"]) == 0:
336
+ return mask_data
337
+
338
+ # Filter small disconnected regions and holes
339
+ new_masks = []
340
+ scores = []
341
+ for rle in mask_data["rles"]:
342
+ mask = rle_to_mask(rle)
343
+
344
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
345
+ unchanged = not changed
346
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
347
+ unchanged = unchanged and not changed
348
+
349
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
350
+ # Give score=0 to changed masks and score=1 to unchanged masks
351
+ # so NMS will prefer ones that didn't need postprocessing
352
+ scores.append(float(unchanged))
353
+
354
+ # Recalculate boxes and remove any new duplicates
355
+ masks = torch.cat(new_masks, dim=0)
356
+ boxes = batched_mask_to_box(masks)
357
+ keep_by_nms = batched_nms(
358
+ boxes.float(),
359
+ torch.as_tensor(scores),
360
+ torch.zeros_like(boxes[:, 0]), # categories
361
+ iou_threshold=nms_thresh,
362
+ )
363
+
364
+ # Only recalculate RLEs for masks that have changed
365
+ for i_mask in keep_by_nms:
366
+ if scores[i_mask] == 0.0:
367
+ mask_torch = masks[i_mask].unsqueeze(0)
368
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
369
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
370
+ mask_data.filter(keep_by_nms)
371
+
372
+ return mask_data
model/segment_anything_volumetric/build_sam.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ from functools import partial
7
+ from pathlib import Path
8
+ import urllib.request
9
+ import torch
10
+
11
+ from .modeling import (
12
+ ImageEncoderViT,
13
+ MaskDecoder,
14
+ PromptEncoder,
15
+ Sam,
16
+ TwoWayTransformer,
17
+ )
18
+ import numpy as np
19
+ from .modeling.image_encoder_swin import SwinTransformer
20
+ from monai.networks.nets import ViT
21
+ from monai.networks.nets.swin_unetr import SwinTransformer as SwinViT
22
+
23
+ from monai.utils import ensure_tuple_rep, optional_import
24
+
25
+
26
+ """
27
+ Examples::
28
+ # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
29
+ >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
30
+ # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
31
+ >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
32
+ # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
33
+ >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
34
+ """
35
+
36
+ def build_sam_vit_3d(checkpoint=None):
37
+ print('build_sam_vit_3d...')
38
+ return _build_sam(
39
+ image_encoder_type='vit',
40
+ embed_dim = 768,
41
+ patch_size=[4,16,16],
42
+ checkpoint=checkpoint,
43
+ image_size=[32,256,256],
44
+ )
45
+
46
+ sam_model_registry = {
47
+ "vit": build_sam_vit_3d,
48
+ }
49
+
50
+
51
+ def _build_sam(
52
+ image_encoder_type,
53
+ embed_dim,
54
+ patch_size,
55
+ checkpoint,
56
+ image_size,
57
+ ):
58
+ mlp_dim = 3072
59
+ num_layers = 12
60
+ num_heads = 12
61
+ pos_embed = 'perceptron'
62
+ dropout_rate = 0.0
63
+
64
+ image_encoder=ViT(
65
+ in_channels=1,
66
+ img_size=image_size,
67
+ patch_size=patch_size,
68
+ hidden_size=embed_dim,
69
+ mlp_dim=mlp_dim,
70
+ num_layers=num_layers,
71
+ num_heads=num_heads,
72
+ pos_embed=pos_embed,
73
+ classification=False,
74
+ dropout_rate=dropout_rate,
75
+ )
76
+ image_embedding_size = [int(item) for item in (np.array(image_size) / np.array(patch_size))]
77
+
78
+ if checkpoint is not None:
79
+ with open(checkpoint, "rb") as f:
80
+ state_dict = torch.load(f, map_location='cpu')['state_dict']
81
+ encoder_dict = {k.replace('model.encoder.', ''): v for k, v in state_dict.items() if 'model.encoder.' in k}
82
+ image_encoder.load_state_dict(encoder_dict)
83
+ print(f'===> image_encoder.load_param: {checkpoint}')
84
+ sam = Sam(
85
+ image_encoder=image_encoder,
86
+ prompt_encoder=PromptEncoder(
87
+ embed_dim=embed_dim,
88
+ image_embedding_size=image_embedding_size,
89
+ input_image_size=image_size,
90
+ mask_in_chans=16,
91
+ ),
92
+ mask_decoder=MaskDecoder(
93
+ image_encoder_type=image_encoder_type,
94
+ num_multimask_outputs=3,
95
+ transformer=TwoWayTransformer(
96
+ depth=2,
97
+ embedding_dim=embed_dim,
98
+ mlp_dim=2048,
99
+ num_heads=8,
100
+ ),
101
+ transformer_dim=embed_dim,
102
+ iou_head_depth=3,
103
+ iou_head_hidden_dim=256,
104
+ image_size=np.array(image_size),
105
+ patch_size=np.array(patch_size),
106
+ ),
107
+ pixel_mean=[123.675, 116.28, 103.53],
108
+ pixel_std=[58.395, 57.12, 57.375],
109
+ )
110
+ sam.eval()
111
+ return sam
model/segment_anything_volumetric/modeling/.ipynb_checkpoints/image_encoder_swin-checkpoint.py ADDED
@@ -0,0 +1,709 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence, Tuple, Type, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.checkpoint as checkpoint
8
+ from torch.nn import LayerNorm
9
+
10
+ from monai.networks.blocks import MLPBlock as Mlp
11
+ from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
12
+ from monai.networks.layers import DropPath, trunc_normal_
13
+ from monai.utils import ensure_tuple_rep, optional_import
14
+
15
+ rearrange, _ = optional_import("einops", name="rearrange")
16
+
17
+ def window_partition(x, window_size):
18
+ """window partition operation based on: "Liu et al.,
19
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
20
+ <https://arxiv.org/abs/2103.14030>"
21
+ https://github.com/microsoft/Swin-Transformer
22
+ Args:
23
+ x: input tensor.
24
+ window_size: local window size.
25
+ """
26
+ x_shape = x.size()
27
+ if len(x_shape) == 5:
28
+ b, d, h, w, c = x_shape
29
+ x = x.view(
30
+ b,
31
+ d // window_size[0],
32
+ window_size[0],
33
+ h // window_size[1],
34
+ window_size[1],
35
+ w // window_size[2],
36
+ window_size[2],
37
+ c,
38
+ )
39
+ windows = (
40
+ x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c)
41
+ )
42
+ elif len(x_shape) == 4:
43
+ b, h, w, c = x.shape
44
+ x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c)
45
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c)
46
+ return windows
47
+
48
+
49
+ def window_reverse(windows, window_size, dims):
50
+ """window reverse operation based on: "Liu et al.,
51
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
52
+ <https://arxiv.org/abs/2103.14030>"
53
+ https://github.com/microsoft/Swin-Transformer
54
+ Args:
55
+ windows: windows tensor.
56
+ window_size: local window size.
57
+ dims: dimension values.
58
+ """
59
+ if len(dims) == 4:
60
+ b, d, h, w = dims
61
+ x = windows.view(
62
+ b,
63
+ d // window_size[0],
64
+ h // window_size[1],
65
+ w // window_size[2],
66
+ window_size[0],
67
+ window_size[1],
68
+ window_size[2],
69
+ -1,
70
+ )
71
+ x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1)
72
+
73
+ elif len(dims) == 3:
74
+ b, h, w = dims
75
+ x = windows.view(b, h // window_size[0], w // window_size[0], window_size[0], window_size[1], -1)
76
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
77
+ return x
78
+
79
+
80
+ def get_window_size(x_size, window_size, shift_size=None):
81
+ """Computing window size based on: "Liu et al.,
82
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
83
+ <https://arxiv.org/abs/2103.14030>"
84
+ https://github.com/microsoft/Swin-Transformer
85
+ Args:
86
+ x_size: input size.
87
+ window_size: local window size.
88
+ shift_size: window shifting size.
89
+ """
90
+
91
+ use_window_size = list(window_size)
92
+ if shift_size is not None:
93
+ use_shift_size = list(shift_size)
94
+ for i in range(len(x_size)):
95
+ if x_size[i] <= window_size[i]:
96
+ use_window_size[i] = x_size[i]
97
+ if shift_size is not None:
98
+ use_shift_size[i] = 0
99
+
100
+ if shift_size is None:
101
+ return tuple(use_window_size)
102
+ else:
103
+ return tuple(use_window_size), tuple(use_shift_size)
104
+
105
+
106
+ class WindowAttention(nn.Module):
107
+ """
108
+ Window based multi-head self attention module with relative position bias based on: "Liu et al.,
109
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
110
+ <https://arxiv.org/abs/2103.14030>"
111
+ https://github.com/microsoft/Swin-Transformer
112
+ """
113
+
114
+ def __init__(
115
+ self,
116
+ dim: int,
117
+ num_heads: int,
118
+ window_size: Sequence[int],
119
+ qkv_bias: bool = False,
120
+ attn_drop: float = 0.0,
121
+ proj_drop: float = 0.0,
122
+ ) -> None:
123
+ """
124
+ Args:
125
+ dim: number of feature channels.
126
+ num_heads: number of attention heads.
127
+ window_size: local window size.
128
+ qkv_bias: add a learnable bias to query, key, value.
129
+ attn_drop: attention dropout rate.
130
+ proj_drop: dropout rate of output.
131
+ """
132
+
133
+ super().__init__()
134
+ self.dim = dim
135
+ self.window_size = window_size
136
+ self.num_heads = num_heads
137
+ head_dim = dim // num_heads
138
+ self.scale = head_dim**-0.5
139
+ mesh_args = torch.meshgrid.__kwdefaults__
140
+
141
+ if len(self.window_size) == 3:
142
+ self.relative_position_bias_table = nn.Parameter(
143
+ torch.zeros(
144
+ (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1),
145
+ num_heads,
146
+ )
147
+ )
148
+ coords_d = torch.arange(self.window_size[0])
149
+ coords_h = torch.arange(self.window_size[1])
150
+ coords_w = torch.arange(self.window_size[2])
151
+ if mesh_args is not None:
152
+ coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij"))
153
+ else:
154
+ coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))
155
+ coords_flatten = torch.flatten(coords, 1)
156
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
157
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
158
+ relative_coords[:, :, 0] += self.window_size[0] - 1
159
+ relative_coords[:, :, 1] += self.window_size[1] - 1
160
+ relative_coords[:, :, 2] += self.window_size[2] - 1
161
+ relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
162
+ relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
163
+ elif len(self.window_size) == 2:
164
+ self.relative_position_bias_table = nn.Parameter(
165
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
166
+ )
167
+ coords_h = torch.arange(self.window_size[0])
168
+ coords_w = torch.arange(self.window_size[1])
169
+ if mesh_args is not None:
170
+ coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
171
+ else:
172
+ coords = torch.stack(torch.meshgrid(coords_h, coords_w))
173
+ coords_flatten = torch.flatten(coords, 1)
174
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
175
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
176
+ relative_coords[:, :, 0] += self.window_size[0] - 1
177
+ relative_coords[:, :, 1] += self.window_size[1] - 1
178
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
179
+
180
+ relative_position_index = relative_coords.sum(-1)
181
+ self.register_buffer("relative_position_index", relative_position_index)
182
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
183
+ self.attn_drop = nn.Dropout(attn_drop)
184
+ self.proj = nn.Linear(dim, dim)
185
+ self.proj_drop = nn.Dropout(proj_drop)
186
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
187
+ self.softmax = nn.Softmax(dim=-1)
188
+
189
+ def forward(self, x, mask):
190
+ b, n, c = x.shape
191
+ qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
192
+ q, k, v = qkv[0], qkv[1], qkv[2]
193
+ q = q * self.scale
194
+ attn = q @ k.transpose(-2, -1)
195
+ relative_position_bias = self.relative_position_bias_table[
196
+ self.relative_position_index.clone()[:n, :n].reshape(-1)
197
+ ].reshape(n, n, -1)
198
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
199
+ attn = attn + relative_position_bias.unsqueeze(0)
200
+ if mask is not None:
201
+ nw = mask.shape[0]
202
+ attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
203
+ attn = attn.view(-1, self.num_heads, n, n)
204
+ attn = self.softmax(attn)
205
+ else:
206
+ attn = self.softmax(attn)
207
+
208
+ attn = self.attn_drop(attn)
209
+ x = (attn @ v).transpose(1, 2).reshape(b, n, c)
210
+ x = self.proj(x)
211
+ x = self.proj_drop(x)
212
+ return x
213
+
214
+
215
+ class SwinTransformerBlock(nn.Module):
216
+ """
217
+ Swin Transformer block based on: "Liu et al.,
218
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
219
+ <https://arxiv.org/abs/2103.14030>"
220
+ https://github.com/microsoft/Swin-Transformer
221
+ """
222
+
223
+ def __init__(
224
+ self,
225
+ dim: int,
226
+ num_heads: int,
227
+ window_size: Sequence[int],
228
+ shift_size: Sequence[int],
229
+ mlp_ratio: float = 4.0,
230
+ qkv_bias: bool = True,
231
+ drop: float = 0.0,
232
+ attn_drop: float = 0.0,
233
+ drop_path: float = 0.0,
234
+ act_layer: str = "GELU",
235
+ norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
236
+ use_checkpoint: bool = False,
237
+ ) -> None:
238
+ """
239
+ Args:
240
+ dim: number of feature channels.
241
+ num_heads: number of attention heads.
242
+ window_size: local window size.
243
+ shift_size: window shift size.
244
+ mlp_ratio: ratio of mlp hidden dim to embedding dim.
245
+ qkv_bias: add a learnable bias to query, key, value.
246
+ drop: dropout rate.
247
+ attn_drop: attention dropout rate.
248
+ drop_path: stochastic depth rate.
249
+ act_layer: activation layer.
250
+ norm_layer: normalization layer.
251
+ use_checkpoint: use gradient checkpointing for reduced memory usage.
252
+ """
253
+
254
+ super().__init__()
255
+ self.dim = dim
256
+ self.num_heads = num_heads
257
+ self.window_size = window_size
258
+ self.shift_size = shift_size
259
+ self.mlp_ratio = mlp_ratio
260
+ self.use_checkpoint = use_checkpoint
261
+ self.norm1 = norm_layer(dim)
262
+ self.attn = WindowAttention(
263
+ dim,
264
+ window_size=self.window_size,
265
+ num_heads=num_heads,
266
+ qkv_bias=qkv_bias,
267
+ attn_drop=attn_drop,
268
+ proj_drop=drop,
269
+ )
270
+
271
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
272
+ self.norm2 = norm_layer(dim)
273
+ mlp_hidden_dim = int(dim * mlp_ratio)
274
+ self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin")
275
+
276
+ def forward_part1(self, x, mask_matrix):
277
+ x_shape = x.size()
278
+ x = self.norm1(x)
279
+ if len(x_shape) == 5:
280
+ b, d, h, w, c = x.shape
281
+ window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
282
+ pad_l = pad_t = pad_d0 = 0
283
+ pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
284
+ pad_b = (window_size[1] - h % window_size[1]) % window_size[1]
285
+ pad_r = (window_size[2] - w % window_size[2]) % window_size[2]
286
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
287
+ _, dp, hp, wp, _ = x.shape
288
+ dims = [b, dp, hp, wp]
289
+
290
+ elif len(x_shape) == 4:
291
+ b, h, w, c = x.shape
292
+ window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
293
+ pad_l = pad_t = 0
294
+ pad_r = (window_size[0] - h % window_size[0]) % window_size[0]
295
+ pad_b = (window_size[1] - w % window_size[1]) % window_size[1]
296
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
297
+ _, hp, wp, _ = x.shape
298
+ dims = [b, hp, wp]
299
+
300
+ if any(i > 0 for i in shift_size):
301
+ if len(x_shape) == 5:
302
+ shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
303
+ elif len(x_shape) == 4:
304
+ shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
305
+ attn_mask = mask_matrix
306
+ else:
307
+ shifted_x = x
308
+ attn_mask = None
309
+ x_windows = window_partition(shifted_x, window_size)
310
+ attn_windows = self.attn(x_windows, mask=attn_mask)
311
+ attn_windows = attn_windows.view(-1, *(window_size + (c,)))
312
+ shifted_x = window_reverse(attn_windows, window_size, dims)
313
+ if any(i > 0 for i in shift_size):
314
+ if len(x_shape) == 5:
315
+ x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
316
+ elif len(x_shape) == 4:
317
+ x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
318
+ else:
319
+ x = shifted_x
320
+
321
+ if len(x_shape) == 5:
322
+ if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
323
+ x = x[:, :d, :h, :w, :].contiguous()
324
+ elif len(x_shape) == 4:
325
+ if pad_r > 0 or pad_b > 0:
326
+ x = x[:, :h, :w, :].contiguous()
327
+
328
+ return x
329
+
330
+ def forward_part2(self, x):
331
+ return self.drop_path(self.mlp(self.norm2(x)))
332
+
333
+ def load_from(self, weights, n_block, layer):
334
+ root = f"module.{layer}.0.blocks.{n_block}."
335
+ block_names = [
336
+ "norm1.weight",
337
+ "norm1.bias",
338
+ "attn.relative_position_bias_table",
339
+ "attn.relative_position_index",
340
+ "attn.qkv.weight",
341
+ "attn.qkv.bias",
342
+ "attn.proj.weight",
343
+ "attn.proj.bias",
344
+ "norm2.weight",
345
+ "norm2.bias",
346
+ "mlp.fc1.weight",
347
+ "mlp.fc1.bias",
348
+ "mlp.fc2.weight",
349
+ "mlp.fc2.bias",
350
+ ]
351
+ with torch.no_grad():
352
+ self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
353
+ self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
354
+ self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
355
+ self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])
356
+ self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
357
+ self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
358
+ self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
359
+ self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]])
360
+ self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]])
361
+ self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]])
362
+ self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]])
363
+ self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]])
364
+ self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]])
365
+ self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]])
366
+
367
+ def forward(self, x, mask_matrix):
368
+ shortcut = x
369
+ if self.use_checkpoint:
370
+ x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
371
+ else:
372
+ x = self.forward_part1(x, mask_matrix)
373
+ x = shortcut + self.drop_path(x)
374
+ if self.use_checkpoint:
375
+ x = x + checkpoint.checkpoint(self.forward_part2, x)
376
+ else:
377
+ x = x + self.forward_part2(x)
378
+ return x
379
+
380
+
381
+ class PatchMerging(nn.Module):
382
+ """
383
+ Patch merging layer based on: "Liu et al.,
384
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
385
+ <https://arxiv.org/abs/2103.14030>"
386
+ https://github.com/microsoft/Swin-Transformer
387
+ """
388
+
389
+ def __init__(
390
+ self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3
391
+ ) -> None: # type: ignore
392
+ """
393
+ Args:
394
+ dim: number of feature channels.
395
+ norm_layer: normalization layer.
396
+ spatial_dims: number of spatial dims.
397
+ """
398
+
399
+ super().__init__()
400
+ self.dim = dim
401
+ if spatial_dims == 3:
402
+ self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
403
+ self.norm = norm_layer(8 * dim)
404
+ elif spatial_dims == 2:
405
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
406
+ self.norm = norm_layer(4 * dim)
407
+
408
+ def forward(self, x):
409
+
410
+ x_shape = x.size()
411
+ if len(x_shape) == 5:
412
+ b, d, h, w, c = x_shape
413
+ pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
414
+ if pad_input:
415
+ x = F.pad(x, (0, 0, 0, d % 2, 0, w % 2, 0, h % 2))
416
+ x0 = x[:, 0::2, 0::2, 0::2, :]
417
+ x1 = x[:, 1::2, 0::2, 0::2, :]
418
+ x2 = x[:, 0::2, 1::2, 0::2, :]
419
+ x3 = x[:, 0::2, 0::2, 1::2, :]
420
+ x4 = x[:, 1::2, 0::2, 1::2, :]
421
+ x5 = x[:, 0::2, 1::2, 0::2, :]
422
+ x6 = x[:, 0::2, 0::2, 1::2, :]
423
+ x7 = x[:, 1::2, 1::2, 1::2, :]
424
+ x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
425
+
426
+ elif len(x_shape) == 4:
427
+ b, h, w, c = x_shape
428
+ pad_input = (h % 2 == 1) or (w % 2 == 1)
429
+ if pad_input:
430
+ x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))
431
+ x0 = x[:, 0::2, 0::2, :]
432
+ x1 = x[:, 1::2, 0::2, :]
433
+ x2 = x[:, 0::2, 1::2, :]
434
+ x3 = x[:, 1::2, 1::2, :]
435
+ x = torch.cat([x0, x1, x2, x3], -1)
436
+
437
+ x = self.norm(x)
438
+ x = self.reduction(x)
439
+ return x
440
+
441
+
442
+ def compute_mask(dims, window_size, shift_size, device):
443
+ """Computing region masks based on: "Liu et al.,
444
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
445
+ <https://arxiv.org/abs/2103.14030>"
446
+ https://github.com/microsoft/Swin-Transformer
447
+ Args:
448
+ dims: dimension values.
449
+ window_size: local window size.
450
+ shift_size: shift size.
451
+ device: device.
452
+ """
453
+
454
+ cnt = 0
455
+
456
+ if len(dims) == 3:
457
+ d, h, w = dims
458
+ img_mask = torch.zeros((1, d, h, w, 1), device=device)
459
+ for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
460
+ for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
461
+ for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
462
+ img_mask[:, d, h, w, :] = cnt
463
+ cnt += 1
464
+
465
+ elif len(dims) == 2:
466
+ h, w = dims
467
+ img_mask = torch.zeros((1, h, w, 1), device=device)
468
+ for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
469
+ for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
470
+ img_mask[:, h, w, :] = cnt
471
+ cnt += 1
472
+
473
+ mask_windows = window_partition(img_mask, window_size)
474
+ mask_windows = mask_windows.squeeze(-1)
475
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
476
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
477
+
478
+ return attn_mask
479
+
480
+
481
+ class BasicLayer(nn.Module):
482
+ """
483
+ Basic Swin Transformer layer in one stage based on: "Liu et al.,
484
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
485
+ <https://arxiv.org/abs/2103.14030>"
486
+ https://github.com/microsoft/Swin-Transformer
487
+ """
488
+
489
+ def __init__(
490
+ self,
491
+ dim: int,
492
+ depth: int,
493
+ num_heads: int,
494
+ window_size: Sequence[int],
495
+ drop_path: list,
496
+ mlp_ratio: float = 4.0,
497
+ qkv_bias: bool = False,
498
+ drop: float = 0.0,
499
+ attn_drop: float = 0.0,
500
+ norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
501
+ downsample: isinstance = None, # type: ignore
502
+ use_checkpoint: bool = False,
503
+ ) -> None:
504
+ """
505
+ Args:
506
+ dim: number of feature channels.
507
+ depths: number of layers in each stage.
508
+ num_heads: number of attention heads.
509
+ window_size: local window size.
510
+ drop_path: stochastic depth rate.
511
+ mlp_ratio: ratio of mlp hidden dim to embedding dim.
512
+ qkv_bias: add a learnable bias to query, key, value.
513
+ drop: dropout rate.
514
+ attn_drop: attention dropout rate.
515
+ norm_layer: normalization layer.
516
+ downsample: downsample layer at the end of the layer.
517
+ use_checkpoint: use gradient checkpointing for reduced memory usage.
518
+ """
519
+
520
+ super().__init__()
521
+ self.window_size = window_size
522
+ self.shift_size = tuple(i // 2 for i in window_size)
523
+ self.no_shift = tuple(0 for i in window_size)
524
+ self.depth = depth
525
+ self.use_checkpoint = use_checkpoint
526
+ self.blocks = nn.ModuleList(
527
+ [
528
+ SwinTransformerBlock(
529
+ dim=dim,
530
+ num_heads=num_heads,
531
+ window_size=self.window_size,
532
+ shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,
533
+ mlp_ratio=mlp_ratio,
534
+ qkv_bias=qkv_bias,
535
+ drop=drop,
536
+ attn_drop=attn_drop,
537
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
538
+ norm_layer=norm_layer,
539
+ use_checkpoint=use_checkpoint,
540
+ )
541
+ for i in range(depth)
542
+ ]
543
+ )
544
+ self.downsample = downsample
545
+ if self.downsample is not None:
546
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))
547
+
548
+ def forward(self, x):
549
+ x_shape = x.size()
550
+ if len(x_shape) == 5:
551
+ b, c, d, h, w = x_shape
552
+ window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
553
+ x = rearrange(x, "b c d h w -> b d h w c")
554
+ dp = int(np.ceil(d / window_size[0])) * window_size[0]
555
+ hp = int(np.ceil(h / window_size[1])) * window_size[1]
556
+ wp = int(np.ceil(w / window_size[2])) * window_size[2]
557
+ attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)
558
+ for blk in self.blocks:
559
+ x = blk(x, attn_mask)
560
+ x = x.view(b, d, h, w, -1)
561
+ if self.downsample is not None:
562
+ x = self.downsample(x)
563
+ x = rearrange(x, "b d h w c -> b c d h w")
564
+
565
+ elif len(x_shape) == 4:
566
+ b, c, h, w = x_shape
567
+ window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
568
+ x = rearrange(x, "b c h w -> b h w c")
569
+ hp = int(np.ceil(h / window_size[0])) * window_size[0]
570
+ wp = int(np.ceil(w / window_size[1])) * window_size[1]
571
+ attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
572
+ for blk in self.blocks:
573
+ x = blk(x, attn_mask)
574
+ x = x.view(b, h, w, -1)
575
+ if self.downsample is not None:
576
+ x = self.downsample(x)
577
+ x = rearrange(x, "b h w c -> b c h w")
578
+ return x
579
+
580
+
581
+ class SwinTransformer(nn.Module):
582
+ """
583
+ Swin Transformer based on: "Liu et al.,
584
+ Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
585
+ <https://arxiv.org/abs/2103.14030>"
586
+ https://github.com/microsoft/Swin-Transformer
587
+ """
588
+
589
+ def __init__(
590
+ self,
591
+ in_chans: int,
592
+ embed_dim: int,
593
+ window_size: Sequence[int],
594
+ patch_size: Sequence[int],
595
+ depths: Sequence[int],
596
+ num_heads: Sequence[int],
597
+ mlp_ratio: float = 4.0,
598
+ qkv_bias: bool = True,
599
+ drop_rate: float = 0.0,
600
+ attn_drop_rate: float = 0.0,
601
+ drop_path_rate: float = 0.0,
602
+ norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
603
+ patch_norm: bool = False,
604
+ use_checkpoint: bool = False,
605
+ spatial_dims: int = 3,
606
+ ) -> None:
607
+ """
608
+ Args:
609
+ in_chans: dimension of input channels.
610
+ embed_dim: number of linear projection output channels.
611
+ window_size: local window size.
612
+ patch_size: patch size.
613
+ depths: number of layers in each stage.
614
+ num_heads: number of attention heads.
615
+ mlp_ratio: ratio of mlp hidden dim to embedding dim.
616
+ qkv_bias: add a learnable bias to query, key, value.
617
+ drop_rate: dropout rate.
618
+ attn_drop_rate: attention dropout rate.
619
+ drop_path_rate: stochastic depth rate.
620
+ norm_layer: normalization layer.
621
+ patch_norm: add normalization after patch embedding.
622
+ use_checkpoint: use gradient checkpointing for reduced memory usage.
623
+ spatial_dims: spatial dimension.
624
+ """
625
+
626
+ super().__init__()
627
+ self.num_layers = len(depths)
628
+ self.embed_dim = embed_dim
629
+ self.patch_norm = patch_norm
630
+ self.window_size = window_size
631
+ self.patch_size = patch_size
632
+ self.patch_embed = PatchEmbed(
633
+ patch_size=self.patch_size,
634
+ in_chans=in_chans,
635
+ embed_dim=embed_dim,
636
+ norm_layer=norm_layer if self.patch_norm else None, # type: ignore
637
+ spatial_dims=spatial_dims,
638
+ )
639
+ self.pos_drop = nn.Dropout(p=drop_rate)
640
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
641
+ # self.layers1 = nn.ModuleList()
642
+ # self.layers2 = nn.ModuleList()
643
+ # self.layers3 = nn.ModuleList()
644
+ # self.layers4 = nn.ModuleList()
645
+ self.layers = nn.ModuleList()
646
+ for i_layer in range(self.num_layers):
647
+ layer = BasicLayer(
648
+ dim=int(embed_dim * 2**i_layer),
649
+ depth=depths[i_layer],
650
+ num_heads=num_heads[i_layer],
651
+ window_size=self.window_size,
652
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
653
+ mlp_ratio=mlp_ratio,
654
+ qkv_bias=qkv_bias,
655
+ drop=drop_rate,
656
+ attn_drop=attn_drop_rate,
657
+ norm_layer=norm_layer,
658
+ downsample=PatchMerging,
659
+ use_checkpoint=use_checkpoint,
660
+ )
661
+ self.layers.append(layer)
662
+ # if i_layer == 0:
663
+ # self.layers1.append(layer)
664
+ # elif i_layer == 1:
665
+ # self.layers2.append(layer)
666
+ # elif i_layer == 2:
667
+ # self.layers3.append(layer)
668
+ # elif i_layer == 3:
669
+ # self.layers4.append(layer)
670
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
671
+
672
+ def proj_out(self, x, normalize=False):
673
+ if normalize:
674
+ x_shape = x.size()
675
+ if len(x_shape) == 5:
676
+ n, ch, d, h, w = x_shape
677
+ x = rearrange(x, "n c d h w -> n d h w c")
678
+ x = F.layer_norm(x, [ch])
679
+ x = rearrange(x, "n d h w c -> n c d h w")
680
+ elif len(x_shape) == 4:
681
+ n, ch, h, w = x_shape
682
+ x = rearrange(x, "n c h w -> n h w c")
683
+ x = F.layer_norm(x, [ch])
684
+ x = rearrange(x, "n h w c -> n c h w")
685
+ return x
686
+
687
+ def forward(self, x, normalize=True):
688
+ # x input: [B*sample, C(1), H, W, D]
689
+ # x = rearrange(x, "b c h w d -> b c d h w")
690
+ # print('>> input: ', x.shape)
691
+ x = self.patch_embed(x)
692
+ # print('>> patch_embed: ', x.shape)
693
+ x = self.pos_drop(x)
694
+ for layer in self.layers:
695
+ x = layer(x.contiguous())
696
+ # print('>> layer: ', x.shape)
697
+ return x
698
+ # # x0_out = self.proj_out(x0, normalize)
699
+ # x1 = self.layers1[0](x0.contiguous())
700
+ # # x1_out = self.proj_out(x1, normalize)
701
+ # x2 = self.layers2[0](x1.contiguous())
702
+ # # x2_out = self.proj_out(x2, normalize)
703
+ # x3 = self.layers3[0](x2.contiguous())
704
+ # # x3_out = self.proj_out(x3, normalize)
705
+ # x4 = self.layers4[0](x3.contiguous())
706
+ # # x4_out = self.proj_out(x4, normalize)
707
+ # # return [x0_out, x1_out, x2_out, x3_out, x4_out]
708
+
709
+
model/segment_anything_volumetric/modeling/.ipynb_checkpoints/prompt_encoder-checkpoint.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch import nn
10
+
11
+ from typing import Any, Optional, Tuple, Type
12
+
13
+ from .common import LayerNorm2d
14
+ import os
15
+
16
+ class PromptEncoder(nn.Module):
17
+ def __init__(
18
+ self,
19
+ embed_dim: int,
20
+ image_embedding_size: Tuple[int, int, int],
21
+ input_image_size: Tuple[int, int, int],
22
+ mask_in_chans: int,
23
+ activation: Type[nn.Module] = nn.GELU,
24
+ ) -> None:
25
+ """
26
+ Encodes prompts for input to SAM's mask decoder.
27
+
28
+ Arguments:
29
+ embed_dim (int): The prompts' embedding dimension
30
+ image_embedding_size (tuple(int, int)): The spatial size of the
31
+ image embedding, as (H, W).
32
+ input_image_size (int): The padded size of the image as input
33
+ to the image encoder, as (H, W).
34
+ mask_in_chans (int): The number of hidden channels used for
35
+ encoding input masks.
36
+ activation (nn.Module): The activation to use when encoding
37
+ input masks.
38
+ """
39
+ super().__init__()
40
+ self.embed_dim = embed_dim
41
+ self.input_image_size = input_image_size
42
+ self.image_embedding_size = image_embedding_size
43
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44
+
45
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46
+ point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
47
+ self.point_embeddings = nn.ModuleList(point_embeddings)
48
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
49
+
50
+ self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1], 4 * image_embedding_size[2])
51
+ self.mask_downscaling = nn.Sequential(
52
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
53
+ LayerNorm2d(mask_in_chans // 4),
54
+ activation(),
55
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
56
+ LayerNorm2d(mask_in_chans),
57
+ activation(),
58
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
59
+ )
60
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
61
+
62
+ def get_dense_pe(self) -> torch.Tensor:
63
+ """
64
+ Returns the positional encoding used to encode point prompts,
65
+ applied to a dense set of points the shape of the image encoding.
66
+
67
+ Returns:
68
+ torch.Tensor: Positional encoding with shape
69
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
70
+ """
71
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
72
+
73
+ def _embed_points(
74
+ self,
75
+ points: torch.Tensor,
76
+ labels: torch.Tensor,
77
+ pad: bool,
78
+ ) -> torch.Tensor:
79
+ """Embeds point prompts."""
80
+ points = points + 0.5 # Shift to center of pixel
81
+ if pad:
82
+ padding_point = torch.zeros((points.shape[0], 1, 3), device=points.device)
83
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
84
+ points = torch.cat([points, padding_point], dim=1)
85
+ labels = torch.cat([labels, padding_label], dim=1)
86
+ point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
87
+ point_embedding[labels == -1] = 0.0
88
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
89
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
90
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
91
+ return point_embedding
92
+
93
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
94
+ """Embeds box prompts."""
95
+ boxes = boxes + 0.5 # Shift to center of pixel
96
+ coords = boxes.reshape(-1, 2, 3)
97
+ corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
98
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
99
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
100
+ return corner_embedding
101
+
102
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
103
+ """Embeds mask inputs."""
104
+ mask_embedding = self.mask_downscaling(masks)
105
+ return mask_embedding
106
+
107
+ def _get_batch_size(
108
+ self,
109
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
110
+ boxes: Optional[torch.Tensor],
111
+ masks: Optional[torch.Tensor],
112
+ text_embedding: Optional[torch.Tensor],
113
+ ) -> int:
114
+ """
115
+ Gets the batch size of the output given the batch size of the input prompts.
116
+ """
117
+ if points is not None:
118
+ return points[0].shape[0]
119
+ elif boxes is not None:
120
+ return boxes.shape[0]
121
+ elif masks is not None:
122
+ return masks.shape[0]
123
+ elif text_embedding is not None:
124
+ return text_embedding.shape[0]
125
+ else:
126
+ return 1
127
+
128
+ def _get_device(self) -> torch.device:
129
+ return self.point_embeddings[0].weight.device
130
+
131
+ def forward(
132
+ self,
133
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
134
+ boxes: Optional[torch.Tensor],
135
+ masks: Optional[torch.Tensor],
136
+ text_embedding: Optional[torch.Tensor],
137
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
138
+ """
139
+ Embeds different types of prompts, returning both sparse and dense
140
+ embeddings.
141
+
142
+ Arguments:
143
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
144
+ and labels to embed.
145
+ boxes (torch.Tensor or none): boxes to embed
146
+ masks (torch.Tensor or none): masks to embed
147
+ text: test prompt (B, 768)
148
+
149
+ Returns:
150
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
151
+ BxNx(embed_dim), where N is determined by the number of input points
152
+ and boxes.
153
+ torch.Tensor: dense embeddings for the masks, in the shape
154
+ Bx(embed_dim)x(embed_H)x(embed_W)
155
+ """
156
+ # print('prompt encoder here...')
157
+
158
+ bs = self._get_batch_size(points, boxes, masks, text_embedding)
159
+ sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
160
+ # print('sparse_embeddings ', sparse_embeddings.shape)
161
+ if points is not None:
162
+ coords, labels = points
163
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
164
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
165
+
166
+ if boxes is not None:
167
+ box_embeddings = self._embed_boxes(boxes)
168
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
169
+
170
+ if text_embedding is not None:
171
+ sparse_embeddings = torch.cat([sparse_embeddings, text_embedding.unsqueeze(dim=1)], dim=1)
172
+
173
+ # print('box_embeddings ', box_embeddings.shape)
174
+ # print('sparse_embeddings after box/point/text', sparse_embeddings.shape)
175
+
176
+ if masks is not None:
177
+ dense_embeddings = self._embed_masks(masks)
178
+ else:
179
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1, 1).expand(
180
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1], self.image_embedding_size[2]
181
+ )
182
+ # print('dense_embeddings ', dense_embeddings.shape)
183
+ return sparse_embeddings, dense_embeddings
184
+
185
+
186
+ class PositionEmbeddingRandom(nn.Module):
187
+ """
188
+ Positional encoding using random spatial frequencies.
189
+ """
190
+
191
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
192
+ super().__init__()
193
+ if scale is None or scale <= 0.0:
194
+ scale = 1.0
195
+ self.register_buffer(
196
+ "positional_encoding_gaussian_matrix",
197
+ scale * torch.randn((3, num_pos_feats)),
198
+ )
199
+
200
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
201
+ """Positionally encode points that are normalized to [0,1]."""
202
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
203
+ coords = 2 * coords - 1
204
+ coords = coords @ self.positional_encoding_gaussian_matrix
205
+ coords = 2 * np.pi * coords
206
+ # outputs d_1 x ... x d_n x C shape
207
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
208
+
209
+ def forward(self, size: Tuple[int, int, int]) -> torch.Tensor:
210
+ """Generate positional encoding for a grid of the specified size."""
211
+ h, w, d = size
212
+ device: Any = self.positional_encoding_gaussian_matrix.device
213
+ grid = torch.ones((h, w, d), device=device, dtype=torch.float32)
214
+ y_embed = grid.cumsum(dim=0) - 0.5
215
+ x_embed = grid.cumsum(dim=1) - 0.5
216
+ z_embed = grid.cumsum(dim=2) - 0.5
217
+ y_embed = y_embed / h
218
+ x_embed = x_embed / w
219
+ z_embed = z_embed / d
220
+
221
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1))
222
+ return pe.permute(3, 0, 1, 2) # C x H x W x D
223
+
224
+ def forward_with_coords(
225
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
226
+ ) -> torch.Tensor:
227
+ """Positionally encode points that are not normalized to [0,1]."""
228
+ coords = coords_input.clone()
229
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
230
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
231
+ coords[:, :, 2] = coords[:, :, 2] / image_size[2]
232
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
model/segment_anything_volumetric/modeling/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .sam import Sam
8
+ from .image_encoder import ImageEncoderViT
9
+ from .mask_decoder import MaskDecoder
10
+ from .prompt_encoder import PromptEncoder
11
+ from .transformer import TwoWayTransformer
model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (394 Bytes). View file
 
model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (424 Bytes). View file
 
model/segment_anything_volumetric/modeling/__pycache__/common.cpython-310.pyc ADDED
Binary file (1.75 kB). View file
 
model/segment_anything_volumetric/modeling/__pycache__/common.cpython-39.pyc ADDED
Binary file (1.77 kB). View file
 
model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-310.pyc ADDED
Binary file (12.6 kB). View file
 
model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-39.pyc ADDED
Binary file (11.4 kB). View file
 
model/segment_anything_volumetric/modeling/__pycache__/image_encoder_swin.cpython-39.pyc ADDED
Binary file (21.5 kB). View file
 
model/segment_anything_volumetric/modeling/__pycache__/mask_decoder.cpython-310.pyc ADDED
Binary file (5.5 kB). View file