yuxin
commited on
Commit
·
2af4882
1
Parent(s):
6c27712
init segvol
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- __pycache__/utils.cpython-39.pyc +0 -0
- app.py +295 -0
- model/LICENSE +21 -0
- model/README.md +74 -0
- model/__pycache__/inference_cpu.cpython-39.pyc +0 -0
- model/asset/model.png +0 -0
- model/asset/overview back.png +0 -0
- model/asset/overview.png +0 -0
- model/config/clip/config.json +157 -0
- model/config/clip/special_tokens_map.json +1 -0
- model/config/clip/tokenizer.json +0 -0
- model/config/clip/tokenizer_config.json +1 -0
- model/config/clip/vocab.json +0 -0
- model/config/config_demo.json +8 -0
- model/data_process/__pycache__/demo_data_process.cpython-39.pyc +0 -0
- model/data_process/demo_data_process.py +91 -0
- model/inference_cpu.py +171 -0
- model/inference_demo.py +219 -0
- model/network/__pycache__/model.cpython-39.pyc +0 -0
- model/network/model.py +91 -0
- model/script/inference_demo.sh +8 -0
- model/segment_anything_volumetric/.ipynb_checkpoints/build_sam-checkpoint.py +172 -0
- model/segment_anything_volumetric/__init__.py +12 -0
- model/segment_anything_volumetric/__pycache__/__init__.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/__init__.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/build_sam.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/build_sam.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/predictor.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/__pycache__/predictor.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/automatic_mask_generator.py +372 -0
- model/segment_anything_volumetric/build_sam.py +111 -0
- model/segment_anything_volumetric/modeling/.ipynb_checkpoints/image_encoder_swin-checkpoint.py +709 -0
- model/segment_anything_volumetric/modeling/.ipynb_checkpoints/prompt_encoder-checkpoint.py +232 -0
- model/segment_anything_volumetric/modeling/__init__.py +11 -0
- model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/common.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/common.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/image_encoder_swin.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/mask_decoder.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/mask_decoder.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/prompt_encoder.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/prompt_encoder.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/sam.cpython-310.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/sam.cpython-39.pyc +0 -0
- model/segment_anything_volumetric/modeling/__pycache__/transformer.cpython-310.pyc +0 -0
__pycache__/utils.cpython-39.pyc
ADDED
Binary file (3.85 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from streamlit_drawable_canvas import st_canvas
|
3 |
+
from streamlit_image_coordinates import streamlit_image_coordinates
|
4 |
+
|
5 |
+
|
6 |
+
from model.data_process.demo_data_process import process_ct_gt
|
7 |
+
import numpy as np
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
from PIL import Image, ImageDraw
|
10 |
+
import monai.transforms as transforms
|
11 |
+
from utils import show_points, make_fig, reflect_points_into_model, initial_rectangle, reflect_json_data_to_3D_box, reflect_box_into_model, run
|
12 |
+
|
13 |
+
print('script run')
|
14 |
+
|
15 |
+
#############################################
|
16 |
+
# init session_state
|
17 |
+
if 'option' not in st.session_state:
|
18 |
+
st.session_state.option = None
|
19 |
+
if 'text_prompt' not in st.session_state:
|
20 |
+
st.session_state.text_prompt = None
|
21 |
+
|
22 |
+
if 'reset_demo_case' not in st.session_state:
|
23 |
+
st.session_state.reset_demo_case = False
|
24 |
+
|
25 |
+
if 'preds_3D' not in st.session_state:
|
26 |
+
st.session_state.preds_3D = None
|
27 |
+
|
28 |
+
if 'data_item' not in st.session_state:
|
29 |
+
st.session_state.data_item = None
|
30 |
+
|
31 |
+
if 'points' not in st.session_state:
|
32 |
+
st.session_state.points = []
|
33 |
+
|
34 |
+
if 'use_text_prompt' not in st.session_state:
|
35 |
+
st.session_state.use_text_prompt = False
|
36 |
+
|
37 |
+
if 'use_point_prompt' not in st.session_state:
|
38 |
+
st.session_state.use_point_prompt = False
|
39 |
+
|
40 |
+
if 'use_box_prompt' not in st.session_state:
|
41 |
+
st.session_state.use_box_prompt = False
|
42 |
+
|
43 |
+
if 'rectangle_3Dbox' not in st.session_state:
|
44 |
+
st.session_state.rectangle_3Dbox = [0,0,0,0,0,0]
|
45 |
+
|
46 |
+
if 'irregular_box' not in st.session_state:
|
47 |
+
st.session_state.irregular_box = False
|
48 |
+
|
49 |
+
if 'running' not in st.session_state:
|
50 |
+
st.session_state.running = False
|
51 |
+
|
52 |
+
if 'transparency' not in st.session_state:
|
53 |
+
st.session_state.transparency = 0.25
|
54 |
+
|
55 |
+
case_list = [
|
56 |
+
'model/asset/Case_image_00001_0000.nii.gz',
|
57 |
+
'cases/FLARE22_Tr_0002_0000.nii.gz',
|
58 |
+
'cases/FLARE22_Tr_0005_0000.nii.gz',
|
59 |
+
'cases/FLARE22_Tr_0034_0000.nii.gz',
|
60 |
+
'cases/FLARE22_Tr_0045_0000.nii.gz'
|
61 |
+
]
|
62 |
+
#############################################
|
63 |
+
|
64 |
+
#############################################
|
65 |
+
# reset functions
|
66 |
+
def clear_prompts():
|
67 |
+
st.session_state.points = []
|
68 |
+
st.session_state.rectangle_3Dbox = [0,0,0,0,0,0]
|
69 |
+
|
70 |
+
def reset_demo_case():
|
71 |
+
st.session_state.data_item = None
|
72 |
+
st.session_state.reset_demo_case = True
|
73 |
+
clear_prompts()
|
74 |
+
|
75 |
+
def clear_file():
|
76 |
+
st.session_state.option = None
|
77 |
+
process_ct_gt.clear()
|
78 |
+
reset_demo_case()
|
79 |
+
clear_prompts()
|
80 |
+
|
81 |
+
#############################################
|
82 |
+
|
83 |
+
st.image(Image.open('model/asset/overview back.png'), use_column_width=True)
|
84 |
+
# modify demo case here
|
85 |
+
demo_type = st.radio(
|
86 |
+
"Demo case source",
|
87 |
+
["Select", "Upload"],
|
88 |
+
on_change=clear_file
|
89 |
+
)
|
90 |
+
|
91 |
+
if demo_type=="Select":
|
92 |
+
uploaded_file = st.selectbox(
|
93 |
+
"Select a demo case",
|
94 |
+
case_list,
|
95 |
+
index=None,
|
96 |
+
placeholder="Select a demo case...",
|
97 |
+
on_change=reset_demo_case
|
98 |
+
)
|
99 |
+
else:
|
100 |
+
uploaded_file = st.file_uploader("Upload demo case(nii.gz)", type='nii.gz', on_change=reset_demo_case)
|
101 |
+
|
102 |
+
st.session_state.option = uploaded_file
|
103 |
+
|
104 |
+
if st.session_state.option is not None and \
|
105 |
+
st.session_state.reset_demo_case or (st.session_state.data_item is None and st.session_state.option is not None):
|
106 |
+
|
107 |
+
st.session_state.data_item = process_ct_gt(st.session_state.option)
|
108 |
+
st.session_state.reset_demo_case = False
|
109 |
+
st.session_state.preds_3D = None
|
110 |
+
|
111 |
+
prompt_col1, prompt_col2 = st.columns(2)
|
112 |
+
|
113 |
+
with prompt_col1:
|
114 |
+
st.session_state.use_text_prompt = st.toggle('Sematic prompt')
|
115 |
+
text_prompt_type = st.radio(
|
116 |
+
"Sematic prompt type",
|
117 |
+
["Predefined", "Custom"],
|
118 |
+
disabled=(not st.session_state.use_text_prompt)
|
119 |
+
)
|
120 |
+
if text_prompt_type == "Predefined":
|
121 |
+
pre_text = st.selectbox(
|
122 |
+
"Predefined anatomical category:",
|
123 |
+
['liver', 'right kidney', 'spleen', 'pancreas', 'aorta', 'inferior vena cava', 'right adrenal gland', 'left adrenal gland', 'gallbladder', 'esophagus', 'stomach', 'duodenum', 'left kidney'],
|
124 |
+
index=None,
|
125 |
+
disabled=(not st.session_state.use_text_prompt)
|
126 |
+
)
|
127 |
+
else:
|
128 |
+
pre_text = st.text_input('Enter an Anatomical word or phrase:', None, max_chars=20,
|
129 |
+
disabled=(not st.session_state.use_text_prompt))
|
130 |
+
if pre_text is None or len(pre_text) > 0:
|
131 |
+
st.session_state.text_prompt = pre_text
|
132 |
+
else:
|
133 |
+
st.session_state.text_prompt = None
|
134 |
+
|
135 |
+
|
136 |
+
with prompt_col2:
|
137 |
+
spatial_prompt_on = st.toggle('Spatial prompt', on_change=clear_prompts)
|
138 |
+
spatial_prompt = st.radio(
|
139 |
+
"Spatial prompt type",
|
140 |
+
["Point prompt", "Box prompt"],
|
141 |
+
on_change=clear_prompts,
|
142 |
+
disabled=(not spatial_prompt_on))
|
143 |
+
|
144 |
+
if spatial_prompt == "Point prompt":
|
145 |
+
st.session_state.use_point_prompt = True
|
146 |
+
st.session_state.use_box_prompt = False
|
147 |
+
elif spatial_prompt == "Box prompt":
|
148 |
+
st.session_state.use_box_prompt = True
|
149 |
+
st.session_state.use_point_prompt = False
|
150 |
+
else:
|
151 |
+
st.session_state.use_point_prompt = False
|
152 |
+
st.session_state.use_box_prompt = False
|
153 |
+
|
154 |
+
if not spatial_prompt_on:
|
155 |
+
st.session_state.use_point_prompt = False
|
156 |
+
st.session_state.use_box_prompt = False
|
157 |
+
|
158 |
+
if not st.session_state.use_text_prompt:
|
159 |
+
st.session_state.text_prompt = None
|
160 |
+
|
161 |
+
if st.session_state.option is None:
|
162 |
+
st.write('please select demo case first')
|
163 |
+
else:
|
164 |
+
image_3D = st.session_state.data_item['z_image'][0].numpy()
|
165 |
+
col_control1, col_control2 = st.columns(2)
|
166 |
+
|
167 |
+
with col_control1:
|
168 |
+
selected_index_z = st.slider('X-Y view', 0, image_3D.shape[0] - 1, 0, key='xy')
|
169 |
+
|
170 |
+
with col_control2:
|
171 |
+
selected_index_y = st.slider('X-Z view', 0, image_3D.shape[1] - 1, 0, key='xz')
|
172 |
+
if st.session_state.use_box_prompt:
|
173 |
+
top, bottom = st.select_slider(
|
174 |
+
'Top and bottom of box',
|
175 |
+
options=range(0, 325),
|
176 |
+
value=(0, 324)
|
177 |
+
)
|
178 |
+
st.session_state.rectangle_3Dbox[0] = top
|
179 |
+
st.session_state.rectangle_3Dbox[3] = bottom
|
180 |
+
col_image1, col_image2 = st.columns(2)
|
181 |
+
|
182 |
+
if st.session_state.preds_3D is not None:
|
183 |
+
st.session_state.transparency = st.slider('Mask opacity', 0.0, 1.0, 0.5)
|
184 |
+
|
185 |
+
with col_image1:
|
186 |
+
|
187 |
+
image_z_array = image_3D[selected_index_z]
|
188 |
+
|
189 |
+
preds_z_array = None
|
190 |
+
if st.session_state.preds_3D is not None:
|
191 |
+
preds_z_array = st.session_state.preds_3D[selected_index_z]
|
192 |
+
|
193 |
+
image_z = make_fig(image_z_array, preds_z_array, st.session_state.points, selected_index_z, 'xy')
|
194 |
+
|
195 |
+
|
196 |
+
if st.session_state.use_point_prompt:
|
197 |
+
value_xy = streamlit_image_coordinates(image_z, width=325)
|
198 |
+
|
199 |
+
if value_xy is not None:
|
200 |
+
point_ax_xy = (selected_index_z, value_xy['y'], value_xy['x'])
|
201 |
+
if len(st.session_state.points) >= 3:
|
202 |
+
st.warning('Max point num is 3', icon="⚠️")
|
203 |
+
elif point_ax_xy not in st.session_state.points:
|
204 |
+
st.session_state.points.append(point_ax_xy)
|
205 |
+
print('point_ax_xy add rerun')
|
206 |
+
st.rerun()
|
207 |
+
elif st.session_state.use_box_prompt:
|
208 |
+
canvas_result_xy = st_canvas(
|
209 |
+
fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
|
210 |
+
stroke_width=3,
|
211 |
+
stroke_color='#2909F1',
|
212 |
+
background_image=image_z,
|
213 |
+
update_streamlit=True,
|
214 |
+
height=325,
|
215 |
+
width=325,
|
216 |
+
drawing_mode='transform',
|
217 |
+
point_display_radius=0,
|
218 |
+
key="canvas_xy",
|
219 |
+
initial_drawing=initial_rectangle,
|
220 |
+
display_toolbar=True
|
221 |
+
)
|
222 |
+
try:
|
223 |
+
print(canvas_result_xy.json_data['objects'][0]['angle'])
|
224 |
+
if canvas_result_xy.json_data['objects'][0]['angle'] != 0:
|
225 |
+
st.warning('Rotating is undefined behavior', icon="⚠️")
|
226 |
+
st.session_state.irregular_box = True
|
227 |
+
else:
|
228 |
+
st.session_state.irregular_box = False
|
229 |
+
reflect_json_data_to_3D_box(canvas_result_xy.json_data, view='xy')
|
230 |
+
except:
|
231 |
+
print('exception')
|
232 |
+
pass
|
233 |
+
else:
|
234 |
+
st.image(image_z, use_column_width=False)
|
235 |
+
|
236 |
+
with col_image2:
|
237 |
+
image_y_array = image_3D[:, selected_index_y, :]
|
238 |
+
|
239 |
+
preds_y_array = None
|
240 |
+
if st.session_state.preds_3D is not None:
|
241 |
+
preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
|
242 |
+
|
243 |
+
image_y = make_fig(image_y_array, preds_y_array, st.session_state.points, selected_index_y, 'xz')
|
244 |
+
|
245 |
+
if st.session_state.use_point_prompt:
|
246 |
+
value_yz = streamlit_image_coordinates(image_y, width=325)
|
247 |
+
|
248 |
+
if value_yz is not None:
|
249 |
+
point_ax_xz = (value_yz['y'], selected_index_y, value_yz['x'])
|
250 |
+
if len(st.session_state.points) >= 3:
|
251 |
+
st.warning('Max point num is 3', icon="⚠️")
|
252 |
+
elif point_ax_xz not in st.session_state.points:
|
253 |
+
st.session_state.points.append(point_ax_xz)
|
254 |
+
print('point_ax_xz add rerun')
|
255 |
+
st.rerun()
|
256 |
+
elif st.session_state.use_box_prompt:
|
257 |
+
if st.session_state.rectangle_3Dbox[1] <= selected_index_y and selected_index_y <= st.session_state.rectangle_3Dbox[4]:
|
258 |
+
draw = ImageDraw.Draw(image_y)
|
259 |
+
#rectangle xz view (upper-left and lower-right)
|
260 |
+
rectangle_coords = [(st.session_state.rectangle_3Dbox[2], st.session_state.rectangle_3Dbox[0]),
|
261 |
+
(st.session_state.rectangle_3Dbox[5], st.session_state.rectangle_3Dbox[3])]
|
262 |
+
# Draw the rectangle on the image
|
263 |
+
draw.rectangle(rectangle_coords, outline='#2909F1', width=3)
|
264 |
+
st.image(image_y, use_column_width=False)
|
265 |
+
else:
|
266 |
+
st.image(image_y, use_column_width=False)
|
267 |
+
|
268 |
+
|
269 |
+
col1, col2, col3 = st.columns(3)
|
270 |
+
|
271 |
+
with col1:
|
272 |
+
if st.button("Clear", use_container_width=True,
|
273 |
+
disabled=(st.session_state.option is None or (len(st.session_state.points)==0 and not st.session_state.use_box_prompt and st.session_state.preds_3D is None))):
|
274 |
+
clear_prompts()
|
275 |
+
st.session_state.preds_3D = None
|
276 |
+
st.rerun()
|
277 |
+
|
278 |
+
with col3:
|
279 |
+
if st.button("Run", type="primary", use_container_width=True,
|
280 |
+
disabled=(
|
281 |
+
st.session_state.data_item is None or
|
282 |
+
(st.session_state.text_prompt is None and len(st.session_state.points) == 0 and st.session_state.use_box_prompt is False) or
|
283 |
+
st.session_state.irregular_box or
|
284 |
+
st.session_state.running
|
285 |
+
)):
|
286 |
+
st.session_state.running = True
|
287 |
+
st.rerun()
|
288 |
+
|
289 |
+
# if len(st.session_state.points) > 0:
|
290 |
+
# st.write(st.session_state.points)
|
291 |
+
|
292 |
+
if st.session_state.running:
|
293 |
+
st.session_state.running = False
|
294 |
+
run()
|
295 |
+
st.rerun()
|
model/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 BAAI-DCAI
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
model/README.md
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SegVol: Universal and Interactive Volumetric Medical Image Segmentation
|
2 |
+
This repo is the official implementation of [SegVol: Universal and Interactive Volumetric Medical Image Segmentation](https://arxiv.org/abs/2311.13385).
|
3 |
+
|
4 |
+
## News🚀
|
5 |
+
(2023.11.24) *You can download weight files of SegVol and ViT(CTs pre-train) [here](https://drive.google.com/drive/folders/1TEJtgctH534Ko5r4i79usJvqmXVuLf54?usp=drive_link).* 🔥
|
6 |
+
|
7 |
+
(2023.11.23) *The brief introduction and instruction have been uploaded.*
|
8 |
+
|
9 |
+
(2023.11.23) *The inference demo code has been uploaded.*
|
10 |
+
|
11 |
+
(2023.11.22) *The first edition of our paper has been uploaded to arXiv.* 📃
|
12 |
+
|
13 |
+
## Introduction
|
14 |
+
<img src="https://github.com/BAAI-DCAI/SegVol/blob/main/asset/overview.png" width="60%" height="60%">
|
15 |
+
|
16 |
+
The SegVol is a universal and interactive model for volumetric medical image segmentation. SegVol accepts **point**, **box** and **text** prompt while output volumetric segmentation. By training on 90k unlabeled Computed Tomography (CT) volumes and 6k labeled CTs, this foundation model supports the segmentation of over 200 anatomical categories.
|
17 |
+
|
18 |
+
We will release SegVol's **inference code**, **training code**, **model params** and **ViT pre-training params** (pre-training is performed over 2,000 epochs on 96k CTs).
|
19 |
+
|
20 |
+
## Usage
|
21 |
+
### Requirements
|
22 |
+
The [pytorch v1.11.0](https://pytorch.org/get-started/previous-versions/) (or higher virsion) is needed first. Following install key requirements using commands:
|
23 |
+
|
24 |
+
```
|
25 |
+
pip install 'monai[all]==0.9.0'
|
26 |
+
pip install einops==0.6.1
|
27 |
+
pip install transformers==4.18.0
|
28 |
+
pip install matplotlib
|
29 |
+
```
|
30 |
+
### Config and run demo script
|
31 |
+
1. You can download the demo case [here](https://drive.google.com/drive/folders/1TEJtgctH534Ko5r4i79usJvqmXVuLf54?usp=drive_link), or download the whole demo dataset [AbdomenCT-1K](https://github.com/JunMa11/AbdomenCT-1K) and choose any demo case you want.
|
32 |
+
2. Please set CT path and Ground Truth path of the case in the [config_demo.json](https://github.com/BAAI-DCAI/SegVol/blob/main/config/config_demo.json).
|
33 |
+
3. After that, config the [inference_demo.sh](https://github.com/BAAI-DCAI/SegVol/blob/main/script/inference_demo.sh) file for execution:
|
34 |
+
|
35 |
+
- `$segvol_ckpt`: the path of SegVol's checkpoint (Download from [here](https://drive.google.com/drive/folders/1TEJtgctH534Ko5r4i79usJvqmXVuLf54?usp=drive_link)).
|
36 |
+
|
37 |
+
- `$work_dir`: any path of folder you want to save the log files and visualizaion results.
|
38 |
+
|
39 |
+
4. Finally, you can control the **prompt type**, **zoom-in-zoom-out mechanism** and **visualizaion switch** [here](https://github.com/BAAI-DCAI/SegVol/blob/35f3ff9c943a74f630e6948051a1fe21aaba91bc/inference_demo.py#L208C11-L208C11).
|
40 |
+
5. Now, just run `bash script/inference_demo.sh` to infer your demo case.
|
41 |
+
|
42 |
+
## Citation
|
43 |
+
If you find this repository helpful, please consider citing:
|
44 |
+
```
|
45 |
+
@misc{du2023segvol,
|
46 |
+
title={SegVol: Universal and Interactive Volumetric Medical Image Segmentation},
|
47 |
+
author={Yuxin Du and Fan Bai and Tiejun Huang and Bo Zhao},
|
48 |
+
year={2023},
|
49 |
+
eprint={2311.13385},
|
50 |
+
archivePrefix={arXiv},
|
51 |
+
primaryClass={cs.CV}
|
52 |
+
}
|
53 |
+
```
|
54 |
+
|
55 |
+
## Acknowledgement
|
56 |
+
Thanks for the following amazing works:
|
57 |
+
|
58 |
+
[HuggingFace](https://huggingface.co/).
|
59 |
+
|
60 |
+
[CLIP](https://github.com/openai/CLIP).
|
61 |
+
|
62 |
+
[MONAI](https://github.com/Project-MONAI/MONAI).
|
63 |
+
|
64 |
+
[Image by brgfx](https://www.freepik.com/free-vector/anatomical-structure-human-bodies_26353260.htm) on Freepik.
|
65 |
+
|
66 |
+
[Image by muammark](https://www.freepik.com/free-vector/people-icon-collection_1157380.htm#query=user&position=2&from_view=search&track=sph) on Freepik.
|
67 |
+
|
68 |
+
[Image by pch.vector](https://www.freepik.com/free-vector/different-phone-hand-gestures-set_9649376.htm#query=Vector%20touch%20screen%20hand%20gestures&position=4&from_view=search&track=ais) on Freepik.
|
69 |
+
|
70 |
+
[Image by starline](https://www.freepik.com/free-vector/set-three-light-bulb-represent-effective-business-idea-concept_37588597.htm#query=idea&position=0&from_view=search&track=sph) on Freepik.
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
|
model/__pycache__/inference_cpu.cpython-39.pyc
ADDED
Binary file (4.67 kB). View file
|
|
model/asset/model.png
ADDED
model/asset/overview back.png
ADDED
model/asset/overview.png
ADDED
model/config/clip/config.json
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "openai/clip-vit-base-patch32",
|
3 |
+
"architectures": [
|
4 |
+
"CLIPModel"
|
5 |
+
],
|
6 |
+
"initializer_factor": 1.0,
|
7 |
+
"logit_scale_init_value": 2.6592,
|
8 |
+
"model_type": "clip",
|
9 |
+
"projection_dim": 512,
|
10 |
+
"text_config": {
|
11 |
+
"_name_or_path": "",
|
12 |
+
"add_cross_attention": false,
|
13 |
+
"architectures": null,
|
14 |
+
"attention_dropout": 0.0,
|
15 |
+
"bad_words_ids": null,
|
16 |
+
"bos_token_id": 0,
|
17 |
+
"chunk_size_feed_forward": 0,
|
18 |
+
"cross_attention_hidden_size": null,
|
19 |
+
"decoder_start_token_id": null,
|
20 |
+
"diversity_penalty": 0.0,
|
21 |
+
"do_sample": false,
|
22 |
+
"dropout": 0.0,
|
23 |
+
"early_stopping": false,
|
24 |
+
"encoder_no_repeat_ngram_size": 0,
|
25 |
+
"eos_token_id": 2,
|
26 |
+
"finetuning_task": null,
|
27 |
+
"forced_bos_token_id": null,
|
28 |
+
"forced_eos_token_id": null,
|
29 |
+
"hidden_act": "quick_gelu",
|
30 |
+
"hidden_size": 512,
|
31 |
+
"id2label": {
|
32 |
+
"0": "LABEL_0",
|
33 |
+
"1": "LABEL_1"
|
34 |
+
},
|
35 |
+
"initializer_factor": 1.0,
|
36 |
+
"initializer_range": 0.02,
|
37 |
+
"intermediate_size": 2048,
|
38 |
+
"is_decoder": false,
|
39 |
+
"is_encoder_decoder": false,
|
40 |
+
"label2id": {
|
41 |
+
"LABEL_0": 0,
|
42 |
+
"LABEL_1": 1
|
43 |
+
},
|
44 |
+
"layer_norm_eps": 1e-05,
|
45 |
+
"length_penalty": 1.0,
|
46 |
+
"max_length": 20,
|
47 |
+
"max_position_embeddings": 77,
|
48 |
+
"min_length": 0,
|
49 |
+
"model_type": "clip_text_model",
|
50 |
+
"no_repeat_ngram_size": 0,
|
51 |
+
"num_attention_heads": 8,
|
52 |
+
"num_beam_groups": 1,
|
53 |
+
"num_beams": 1,
|
54 |
+
"num_hidden_layers": 12,
|
55 |
+
"num_return_sequences": 1,
|
56 |
+
"output_attentions": false,
|
57 |
+
"output_hidden_states": false,
|
58 |
+
"output_scores": false,
|
59 |
+
"pad_token_id": 1,
|
60 |
+
"prefix": null,
|
61 |
+
"projection_dim": 512,
|
62 |
+
"problem_type": null,
|
63 |
+
"pruned_heads": {},
|
64 |
+
"remove_invalid_values": false,
|
65 |
+
"repetition_penalty": 1.0,
|
66 |
+
"return_dict": true,
|
67 |
+
"return_dict_in_generate": false,
|
68 |
+
"sep_token_id": null,
|
69 |
+
"task_specific_params": null,
|
70 |
+
"temperature": 1.0,
|
71 |
+
"tie_encoder_decoder": false,
|
72 |
+
"tie_word_embeddings": true,
|
73 |
+
"tokenizer_class": null,
|
74 |
+
"top_k": 50,
|
75 |
+
"top_p": 1.0,
|
76 |
+
"torch_dtype": null,
|
77 |
+
"torchscript": false,
|
78 |
+
"transformers_version": "4.16.0.dev0",
|
79 |
+
"use_bfloat16": false,
|
80 |
+
"vocab_size": 49408
|
81 |
+
},
|
82 |
+
"text_config_dict": null,
|
83 |
+
"transformers_version": null,
|
84 |
+
"vision_config": {
|
85 |
+
"_name_or_path": "",
|
86 |
+
"add_cross_attention": false,
|
87 |
+
"architectures": null,
|
88 |
+
"attention_dropout": 0.0,
|
89 |
+
"bad_words_ids": null,
|
90 |
+
"bos_token_id": null,
|
91 |
+
"chunk_size_feed_forward": 0,
|
92 |
+
"cross_attention_hidden_size": null,
|
93 |
+
"decoder_start_token_id": null,
|
94 |
+
"diversity_penalty": 0.0,
|
95 |
+
"do_sample": false,
|
96 |
+
"dropout": 0.0,
|
97 |
+
"early_stopping": false,
|
98 |
+
"encoder_no_repeat_ngram_size": 0,
|
99 |
+
"eos_token_id": null,
|
100 |
+
"finetuning_task": null,
|
101 |
+
"forced_bos_token_id": null,
|
102 |
+
"forced_eos_token_id": null,
|
103 |
+
"hidden_act": "quick_gelu",
|
104 |
+
"hidden_size": 768,
|
105 |
+
"id2label": {
|
106 |
+
"0": "LABEL_0",
|
107 |
+
"1": "LABEL_1"
|
108 |
+
},
|
109 |
+
"image_size": 224,
|
110 |
+
"initializer_factor": 1.0,
|
111 |
+
"initializer_range": 0.02,
|
112 |
+
"intermediate_size": 3072,
|
113 |
+
"is_decoder": false,
|
114 |
+
"is_encoder_decoder": false,
|
115 |
+
"label2id": {
|
116 |
+
"LABEL_0": 0,
|
117 |
+
"LABEL_1": 1
|
118 |
+
},
|
119 |
+
"layer_norm_eps": 1e-05,
|
120 |
+
"length_penalty": 1.0,
|
121 |
+
"max_length": 20,
|
122 |
+
"min_length": 0,
|
123 |
+
"model_type": "clip_vision_model",
|
124 |
+
"no_repeat_ngram_size": 0,
|
125 |
+
"num_attention_heads": 12,
|
126 |
+
"num_beam_groups": 1,
|
127 |
+
"num_beams": 1,
|
128 |
+
"num_hidden_layers": 12,
|
129 |
+
"num_return_sequences": 1,
|
130 |
+
"output_attentions": false,
|
131 |
+
"output_hidden_states": false,
|
132 |
+
"output_scores": false,
|
133 |
+
"pad_token_id": null,
|
134 |
+
"patch_size": 32,
|
135 |
+
"prefix": null,
|
136 |
+
"projection_dim" : 512,
|
137 |
+
"problem_type": null,
|
138 |
+
"pruned_heads": {},
|
139 |
+
"remove_invalid_values": false,
|
140 |
+
"repetition_penalty": 1.0,
|
141 |
+
"return_dict": true,
|
142 |
+
"return_dict_in_generate": false,
|
143 |
+
"sep_token_id": null,
|
144 |
+
"task_specific_params": null,
|
145 |
+
"temperature": 1.0,
|
146 |
+
"tie_encoder_decoder": false,
|
147 |
+
"tie_word_embeddings": true,
|
148 |
+
"tokenizer_class": null,
|
149 |
+
"top_k": 50,
|
150 |
+
"top_p": 1.0,
|
151 |
+
"torch_dtype": null,
|
152 |
+
"torchscript": false,
|
153 |
+
"transformers_version": "4.16.0.dev0",
|
154 |
+
"use_bfloat16": false
|
155 |
+
},
|
156 |
+
"vision_config_dict": null
|
157 |
+
}
|
model/config/clip/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
|
model/config/clip/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
model/config/clip/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": "<|endoftext|>", "add_prefix_space": false, "errors": "replace", "do_lower_case": true, "name_or_path": "./clip_ViT_B_32/"}
|
model/config/clip/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
model/config/config_demo.json
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_name": "AbdomenCT-1k",
|
3 |
+
"categories": ["liver", "kidney", "spleen", "pancreas"],
|
4 |
+
"demo_case": {
|
5 |
+
"ct_path": "path/to/Case_image",
|
6 |
+
"gt_path": "path/to/Case_label"
|
7 |
+
}
|
8 |
+
}
|
model/data_process/__pycache__/demo_data_process.cpython-39.pyc
ADDED
Binary file (3.26 kB). View file
|
|
model/data_process/demo_data_process.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import monai.transforms as transforms
|
3 |
+
import streamlit as st
|
4 |
+
import tempfile
|
5 |
+
|
6 |
+
class MinMaxNormalization(transforms.Transform):
|
7 |
+
def __call__(self, data):
|
8 |
+
d = dict(data)
|
9 |
+
k = "image"
|
10 |
+
d[k] = d[k] - d[k].min()
|
11 |
+
d[k] = d[k] / np.clip(d[k].max(), a_min=1e-8, a_max=None)
|
12 |
+
return d
|
13 |
+
|
14 |
+
class DimTranspose(transforms.Transform):
|
15 |
+
def __init__(self, keys):
|
16 |
+
self.keys = keys
|
17 |
+
|
18 |
+
def __call__(self, data):
|
19 |
+
d = dict(data)
|
20 |
+
for key in self.keys:
|
21 |
+
d[key] = np.swapaxes(d[key], -1, -3)
|
22 |
+
return d
|
23 |
+
|
24 |
+
class ForegroundNormalization(transforms.Transform):
|
25 |
+
def __init__(self, keys):
|
26 |
+
self.keys = keys
|
27 |
+
|
28 |
+
def __call__(self, data):
|
29 |
+
d = dict(data)
|
30 |
+
|
31 |
+
for key in self.keys:
|
32 |
+
d[key] = self.normalize(d[key])
|
33 |
+
return d
|
34 |
+
|
35 |
+
def normalize(self, ct_narray):
|
36 |
+
ct_voxel_ndarray = ct_narray.copy()
|
37 |
+
ct_voxel_ndarray = ct_voxel_ndarray.flatten()
|
38 |
+
thred = np.mean(ct_voxel_ndarray)
|
39 |
+
voxel_filtered = ct_voxel_ndarray[(ct_voxel_ndarray > thred)]
|
40 |
+
upper_bound = np.percentile(voxel_filtered, 99.95)
|
41 |
+
lower_bound = np.percentile(voxel_filtered, 00.05)
|
42 |
+
mean = np.mean(voxel_filtered)
|
43 |
+
std = np.std(voxel_filtered)
|
44 |
+
### transform ###
|
45 |
+
ct_narray = np.clip(ct_narray, lower_bound, upper_bound)
|
46 |
+
ct_narray = (ct_narray - mean) / max(std, 1e-8)
|
47 |
+
return ct_narray
|
48 |
+
|
49 |
+
@st.cache_data
|
50 |
+
def process_ct_gt(case_path, spatial_size=(32,256,256)):
|
51 |
+
if case_path is None:
|
52 |
+
return None
|
53 |
+
print('Data preprocessing...')
|
54 |
+
# transform
|
55 |
+
img_loader = transforms.LoadImage(dtype=np.float32)
|
56 |
+
transform = transforms.Compose(
|
57 |
+
[
|
58 |
+
transforms.Orientationd(keys=["image"], axcodes="RAS"),
|
59 |
+
ForegroundNormalization(keys=["image"]),
|
60 |
+
DimTranspose(keys=["image"]),
|
61 |
+
MinMaxNormalization(),
|
62 |
+
transforms.SpatialPadd(keys=["image"], spatial_size=spatial_size, mode='constant'),
|
63 |
+
transforms.CropForegroundd(keys=["image"], source_key="image"),
|
64 |
+
transforms.ToTensord(keys=["image"]),
|
65 |
+
]
|
66 |
+
)
|
67 |
+
zoom_out_transform = transforms.Resized(keys=["image"], spatial_size=spatial_size, mode='nearest-exact')
|
68 |
+
z_transform = transforms.Resized(keys=["image"], spatial_size=(325,325,325), mode='nearest-exact')
|
69 |
+
###
|
70 |
+
item = {}
|
71 |
+
# generate ct_voxel_ndarray
|
72 |
+
if type(case_path) is str:
|
73 |
+
ct_voxel_ndarray, _ = img_loader(case_path)
|
74 |
+
else:
|
75 |
+
bytes_data = case_path.read()
|
76 |
+
with tempfile.NamedTemporaryFile(suffix='.nii.gz') as tmp:
|
77 |
+
tmp.write(bytes_data)
|
78 |
+
tmp.seek(0)
|
79 |
+
ct_voxel_ndarray, _ = img_loader(tmp.name)
|
80 |
+
ct_voxel_ndarray = np.array(ct_voxel_ndarray).squeeze()
|
81 |
+
ct_voxel_ndarray = np.expand_dims(ct_voxel_ndarray, axis=0)
|
82 |
+
item['image'] = ct_voxel_ndarray
|
83 |
+
|
84 |
+
# transform
|
85 |
+
item = transform(item)
|
86 |
+
item_zoom_out = zoom_out_transform(item)
|
87 |
+
item['zoom_out_image'] = item_zoom_out['image']
|
88 |
+
|
89 |
+
item_z = z_transform(item)
|
90 |
+
item['z_image'] = item_z['image']
|
91 |
+
return item
|
model/inference_cpu.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import json
|
6 |
+
import monai.transforms as transforms
|
7 |
+
|
8 |
+
from model.segment_anything_volumetric import sam_model_registry
|
9 |
+
from model.network.model import SegVol
|
10 |
+
from model.data_process.demo_data_process import process_ct_gt
|
11 |
+
from model.utils.monai_inferers_utils import sliding_window_inference, generate_box, select_points, build_binary_cube, build_binary_points, logits2roi_coor
|
12 |
+
from model.utils.visualize import draw_result
|
13 |
+
import streamlit as st
|
14 |
+
|
15 |
+
def set_parse():
|
16 |
+
# %% set up parser
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument("--test_mode", default=True, type=bool)
|
19 |
+
parser.add_argument("--resume", type = str, default = 'model/asset/SegVol_v1.pth')
|
20 |
+
parser.add_argument("-infer_overlap", default=0.0, type=float, help="sliding window inference overlap")
|
21 |
+
parser.add_argument("-spatial_size", default=(32, 256, 256), type=tuple)
|
22 |
+
parser.add_argument("-patch_size", default=(4, 16, 16), type=tuple)
|
23 |
+
parser.add_argument('-work_dir', type=str, default='./work_dir')
|
24 |
+
### demo
|
25 |
+
parser.add_argument("--clip_ckpt", type = str, default = 'model/config/clip')
|
26 |
+
args = parser.parse_args()
|
27 |
+
return args
|
28 |
+
|
29 |
+
def zoom_in_zoom_out(args, segvol_model, image, image_resize, text_prompt, point_prompt, box_prompt):
|
30 |
+
image_single_resize = image_resize
|
31 |
+
image_single = image[0,0]
|
32 |
+
ori_shape = image_single.shape
|
33 |
+
|
34 |
+
# generate prompts
|
35 |
+
text_single = None if text_prompt is None else [text_prompt]
|
36 |
+
points_single = None
|
37 |
+
box_single = None
|
38 |
+
|
39 |
+
if args.use_point_prompt:
|
40 |
+
point, point_label = point_prompt
|
41 |
+
points_single = (point.unsqueeze(0).float(), point_label.unsqueeze(0).float())
|
42 |
+
binary_points_resize = build_binary_points(point, point_label, ori_shape)
|
43 |
+
if args.use_box_prompt:
|
44 |
+
box_single = box_prompt.unsqueeze(0).float()
|
45 |
+
binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=ori_shape)
|
46 |
+
|
47 |
+
####################
|
48 |
+
# zoom-out inference:
|
49 |
+
print('--- zoom out inference ---')
|
50 |
+
print(text_single)
|
51 |
+
print(f'use text-prompt [{text_single!=None}], use box-prompt [{box_single!=None}], use point-prompt [{points_single!=None}]')
|
52 |
+
with torch.no_grad():
|
53 |
+
logits_global_single = segvol_model(image_single_resize,
|
54 |
+
text=text_single,
|
55 |
+
boxes=box_single,
|
56 |
+
points=points_single)
|
57 |
+
|
58 |
+
# resize back global logits
|
59 |
+
logits_global_single = F.interpolate(
|
60 |
+
logits_global_single.cpu(),
|
61 |
+
size=ori_shape, mode='nearest')[0][0]
|
62 |
+
|
63 |
+
# build prompt reflection for zoom-in
|
64 |
+
if args.use_point_prompt:
|
65 |
+
binary_points = F.interpolate(
|
66 |
+
binary_points_resize.unsqueeze(0).unsqueeze(0).float(),
|
67 |
+
size=ori_shape, mode='nearest')[0][0]
|
68 |
+
if args.use_box_prompt:
|
69 |
+
binary_cube = F.interpolate(
|
70 |
+
binary_cube_resize.unsqueeze(0).unsqueeze(0).float(),
|
71 |
+
size=ori_shape, mode='nearest')[0][0]
|
72 |
+
# draw_result('unknow', image_single_resize, None, point_prompt, logits_global_single, logits_global_single)
|
73 |
+
if not args.use_zoom_in:
|
74 |
+
return logits_global_single
|
75 |
+
|
76 |
+
####################
|
77 |
+
# zoom-in inference:
|
78 |
+
min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(args.spatial_size, logits_global_single)
|
79 |
+
if min_d is None:
|
80 |
+
print('Fail to detect foreground!')
|
81 |
+
return logits_global_single
|
82 |
+
|
83 |
+
# Crop roi
|
84 |
+
image_single_cropped = image_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
|
85 |
+
global_preds = (torch.sigmoid(logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()
|
86 |
+
|
87 |
+
assert not (args.use_box_prompt and args.use_point_prompt)
|
88 |
+
# label_single_cropped = label_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
|
89 |
+
prompt_reflection = None
|
90 |
+
if args.use_box_prompt:
|
91 |
+
binary_cube_cropped = binary_cube[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
|
92 |
+
prompt_reflection = (
|
93 |
+
binary_cube_cropped.unsqueeze(0).unsqueeze(0),
|
94 |
+
global_preds.unsqueeze(0).unsqueeze(0)
|
95 |
+
)
|
96 |
+
if args.use_point_prompt:
|
97 |
+
binary_points_cropped = binary_points[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
|
98 |
+
prompt_reflection = (
|
99 |
+
binary_points_cropped.unsqueeze(0).unsqueeze(0),
|
100 |
+
global_preds.unsqueeze(0).unsqueeze(0)
|
101 |
+
)
|
102 |
+
|
103 |
+
## inference
|
104 |
+
with torch.no_grad():
|
105 |
+
logits_single_cropped = sliding_window_inference(
|
106 |
+
image_single_cropped, prompt_reflection,
|
107 |
+
args.spatial_size, 1, segvol_model, args.infer_overlap,
|
108 |
+
text=text_single,
|
109 |
+
use_box=args.use_box_prompt,
|
110 |
+
use_point=args.use_point_prompt,
|
111 |
+
logits_global_single=logits_global_single,
|
112 |
+
)
|
113 |
+
logits_single_cropped = logits_single_cropped.cpu().squeeze()
|
114 |
+
if logits_single_cropped.shape != logits_global_single.shape:
|
115 |
+
logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped
|
116 |
+
|
117 |
+
return logits_global_single
|
118 |
+
|
119 |
+
@st.cache_resource
|
120 |
+
def build_model():
|
121 |
+
# build model
|
122 |
+
clip_ckpt = 'model/config/clip'
|
123 |
+
resume = 'model/asset/SegVol_v1.pth'
|
124 |
+
sam_model = sam_model_registry['vit']()
|
125 |
+
segvol_model = SegVol(
|
126 |
+
image_encoder=sam_model.image_encoder,
|
127 |
+
mask_decoder=sam_model.mask_decoder,
|
128 |
+
prompt_encoder=sam_model.prompt_encoder,
|
129 |
+
clip_ckpt=clip_ckpt,
|
130 |
+
roi_size=(32,256,256),
|
131 |
+
patch_size=(4,16,16),
|
132 |
+
test_mode=True,
|
133 |
+
)
|
134 |
+
segvol_model = torch.nn.DataParallel(segvol_model)
|
135 |
+
segvol_model.eval()
|
136 |
+
# load param
|
137 |
+
if os.path.isfile(resume):
|
138 |
+
## Map model to be loaded to specified single GPU
|
139 |
+
loc = 'cpu'
|
140 |
+
checkpoint = torch.load(resume, map_location=loc)
|
141 |
+
segvol_model.load_state_dict(checkpoint['model'], strict=True)
|
142 |
+
print("loaded checkpoint '{}' (epoch {})".format(resume, checkpoint['epoch']))
|
143 |
+
print('model build done!')
|
144 |
+
return segvol_model
|
145 |
+
|
146 |
+
@st.cache_data
|
147 |
+
def inference_case(_image, _image_zoom_out, _point_prompt, text_prompt, _box_prompt):
|
148 |
+
# seg config
|
149 |
+
args = set_parse()
|
150 |
+
args.use_zoom_in = True
|
151 |
+
args.use_text_prompt = text_prompt is not None
|
152 |
+
args.use_box_prompt = _box_prompt is not None
|
153 |
+
args.use_point_prompt = _point_prompt is not None
|
154 |
+
|
155 |
+
segvol_model = build_model()
|
156 |
+
|
157 |
+
# run inference
|
158 |
+
logits = zoom_in_zoom_out(
|
159 |
+
args, segvol_model,
|
160 |
+
_image.unsqueeze(0), _image_zoom_out.unsqueeze(0),
|
161 |
+
text_prompt, _point_prompt, _box_prompt)
|
162 |
+
print(logits.shape)
|
163 |
+
resize_transform = transforms.Compose([
|
164 |
+
transforms.AddChannel(),
|
165 |
+
transforms.Resize((325,325,325), mode='trilinear')
|
166 |
+
]
|
167 |
+
)
|
168 |
+
logits = resize_transform(logits)[0]
|
169 |
+
print(logits.shape)
|
170 |
+
return (torch.sigmoid(logits) > 0.5).int().numpy()
|
171 |
+
|
model/inference_demo.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import json
|
6 |
+
from segment_anything_volumetric import sam_model_registry
|
7 |
+
from network.model import SegVol
|
8 |
+
from data_process.demo_data_process import process_ct_gt
|
9 |
+
import monai.transforms as transforms
|
10 |
+
from utils.monai_inferers_utils import sliding_window_inference, generate_box, select_points, build_binary_cube, build_binary_points, logits2roi_coor
|
11 |
+
from utils.visualize import draw_result
|
12 |
+
|
13 |
+
def set_parse():
|
14 |
+
# %% set up parser
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument("--test_mode", default=True, type=bool)
|
17 |
+
parser.add_argument("--resume", type = str, default = '')
|
18 |
+
parser.add_argument("-infer_overlap", default=0.5, type=float, help="sliding window inference overlap")
|
19 |
+
parser.add_argument("-spatial_size", default=(32, 256, 256), type=tuple)
|
20 |
+
parser.add_argument("-patch_size", default=(4, 16, 16), type=tuple)
|
21 |
+
parser.add_argument('-work_dir', type=str, default='./work_dir')
|
22 |
+
### demo
|
23 |
+
parser.add_argument('--demo_config', type=str, required=True)
|
24 |
+
parser.add_argument("--clip_ckpt", type = str, default = './config/clip')
|
25 |
+
args = parser.parse_args()
|
26 |
+
return args
|
27 |
+
|
28 |
+
def dice_score(preds, labels): # on GPU
|
29 |
+
assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match\n" + str(preds.shape) + str(labels.shape)
|
30 |
+
predict = preds.view(1, -1)
|
31 |
+
target = labels.view(1, -1)
|
32 |
+
if target.shape[1] < 1e8:
|
33 |
+
predict = predict.cuda()
|
34 |
+
target = target.cuda()
|
35 |
+
predict = torch.sigmoid(predict)
|
36 |
+
predict = torch.where(predict > 0.5, 1., 0.)
|
37 |
+
|
38 |
+
tp = torch.sum(torch.mul(predict, target))
|
39 |
+
den = torch.sum(predict) + torch.sum(target) + 1
|
40 |
+
dice = 2 * tp / den
|
41 |
+
|
42 |
+
if target.shape[1] < 1e8:
|
43 |
+
predict = predict.cpu()
|
44 |
+
target = target.cpu()
|
45 |
+
return dice
|
46 |
+
|
47 |
+
def zoom_in_zoom_out(args, segvol_model, image, image_resize, gt3D, gt3D_resize, categories=None):
|
48 |
+
logits_labels_record = {}
|
49 |
+
image_single_resize = image_resize
|
50 |
+
image_single = image[0,0]
|
51 |
+
ori_shape = image_single.shape
|
52 |
+
for item_idx in range(len(categories)):
|
53 |
+
# get label to generate prompts
|
54 |
+
label_single = gt3D[0][item_idx]
|
55 |
+
label_single_resize = gt3D_resize[0][item_idx]
|
56 |
+
# skip meaningless categories
|
57 |
+
if torch.sum(label_single) == 0:
|
58 |
+
print('No object, skip')
|
59 |
+
continue
|
60 |
+
# generate prompts
|
61 |
+
text_single = categories[item_idx] if args.use_text_prompt else None
|
62 |
+
if categories is not None: print(f'inference |{categories[item_idx]}| target...')
|
63 |
+
points_single = None
|
64 |
+
box_single = None
|
65 |
+
if args.use_point_prompt:
|
66 |
+
point, point_label = select_points(label_single_resize, num_positive_extra=3, num_negative_extra=3)
|
67 |
+
points_single = (point.unsqueeze(0).float().cuda(), point_label.unsqueeze(0).float().cuda())
|
68 |
+
binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape)
|
69 |
+
if args.use_box_prompt:
|
70 |
+
box_single = generate_box(label_single_resize).unsqueeze(0).float().cuda()
|
71 |
+
binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape)
|
72 |
+
|
73 |
+
####################
|
74 |
+
# zoom-out inference:
|
75 |
+
print('--- zoom out inference ---')
|
76 |
+
print(f'use text-prompt [{text_single!=None}], use box-prompt [{box_single!=None}], use point-prompt [{points_single!=None}]')
|
77 |
+
with torch.no_grad():
|
78 |
+
logits_global_single = segvol_model(image_single_resize.cuda(),
|
79 |
+
text=text_single,
|
80 |
+
boxes=box_single,
|
81 |
+
points=points_single)
|
82 |
+
|
83 |
+
# resize back global logits
|
84 |
+
logits_global_single = F.interpolate(
|
85 |
+
logits_global_single.cpu(),
|
86 |
+
size=ori_shape, mode='nearest')[0][0]
|
87 |
+
|
88 |
+
# build prompt reflection for zoom-in
|
89 |
+
if args.use_point_prompt:
|
90 |
+
binary_points = F.interpolate(
|
91 |
+
binary_points_resize.unsqueeze(0).unsqueeze(0).float(),
|
92 |
+
size=ori_shape, mode='nearest')[0][0]
|
93 |
+
if args.use_box_prompt:
|
94 |
+
binary_cube = F.interpolate(
|
95 |
+
binary_cube_resize.unsqueeze(0).unsqueeze(0).float(),
|
96 |
+
size=ori_shape, mode='nearest')[0][0]
|
97 |
+
zoom_out_dice = dice_score(logits_global_single.squeeze(), label_single.squeeze())
|
98 |
+
logits_labels_record[categories[item_idx]] = (
|
99 |
+
zoom_out_dice,
|
100 |
+
image_single,
|
101 |
+
points_single,
|
102 |
+
box_single,
|
103 |
+
logits_global_single,
|
104 |
+
label_single)
|
105 |
+
print(f'zoom out inference done with zoom_out_dice: {zoom_out_dice:.4f}')
|
106 |
+
if not args.use_zoom_in:
|
107 |
+
continue
|
108 |
+
|
109 |
+
####################
|
110 |
+
# zoom-in inference:
|
111 |
+
min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(args.spatial_size, logits_global_single)
|
112 |
+
if min_d is None:
|
113 |
+
print('Fail to detect foreground!')
|
114 |
+
continue
|
115 |
+
|
116 |
+
# Crop roi
|
117 |
+
image_single_cropped = image_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
|
118 |
+
global_preds = (torch.sigmoid(logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()
|
119 |
+
|
120 |
+
assert not (args.use_box_prompt and args.use_point_prompt)
|
121 |
+
# label_single_cropped = label_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
|
122 |
+
prompt_reflection = None
|
123 |
+
if args.use_box_prompt:
|
124 |
+
binary_cube_cropped = binary_cube[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
|
125 |
+
prompt_reflection = (
|
126 |
+
binary_cube_cropped.unsqueeze(0).unsqueeze(0),
|
127 |
+
global_preds.unsqueeze(0).unsqueeze(0)
|
128 |
+
)
|
129 |
+
if args.use_point_prompt:
|
130 |
+
binary_points_cropped = binary_points[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
|
131 |
+
prompt_reflection = (
|
132 |
+
binary_points_cropped.unsqueeze(0).unsqueeze(0),
|
133 |
+
global_preds.unsqueeze(0).unsqueeze(0)
|
134 |
+
)
|
135 |
+
|
136 |
+
## inference
|
137 |
+
with torch.no_grad():
|
138 |
+
logits_single_cropped = sliding_window_inference(
|
139 |
+
image_single_cropped.cuda(), prompt_reflection,
|
140 |
+
args.spatial_size, 1, segvol_model, args.infer_overlap,
|
141 |
+
text=text_single,
|
142 |
+
use_box=args.use_box_prompt,
|
143 |
+
use_point=args.use_point_prompt,
|
144 |
+
)
|
145 |
+
logits_single_cropped = logits_single_cropped.cpu().squeeze()
|
146 |
+
logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped
|
147 |
+
zoom_in_dice = dice_score(logits_global_single.squeeze(), label_single.squeeze())
|
148 |
+
logits_labels_record[categories[item_idx]] = (
|
149 |
+
zoom_in_dice,
|
150 |
+
image_single,
|
151 |
+
points_single,
|
152 |
+
box_single,
|
153 |
+
logits_global_single,
|
154 |
+
label_single)
|
155 |
+
print(f'===> zoom out dice {zoom_out_dice:.4f} -> zoom-out-zoom-in dice {zoom_in_dice:.4f} <===')
|
156 |
+
return logits_labels_record
|
157 |
+
|
158 |
+
def inference_single_ct(args, segvol_model, data_item, categories):
|
159 |
+
segvol_model.eval()
|
160 |
+
image, gt3D = data_item["image"].float(), data_item["label"]
|
161 |
+
image_zoom_out, gt3D__zoom_out = data_item["zoom_out_image"].float(), data_item['zoom_out_label']
|
162 |
+
|
163 |
+
logits_labels_record = zoom_in_zoom_out(
|
164 |
+
args, segvol_model,
|
165 |
+
image.unsqueeze(0), image_zoom_out.unsqueeze(0),
|
166 |
+
gt3D.unsqueeze(0), gt3D__zoom_out.unsqueeze(0), # add batch dim
|
167 |
+
categories=categories)
|
168 |
+
|
169 |
+
# visualize
|
170 |
+
if args.visualize:
|
171 |
+
for target, values in logits_labels_record.items():
|
172 |
+
dice_score, image, point_prompt, box_prompt, logits, labels = values
|
173 |
+
print(f'{target} result with Dice score {dice_score:.4f} visualizing')
|
174 |
+
draw_result(target + f"-Dice {dice_score:.4f}", image, box_prompt, point_prompt, logits, labels, args.spatial_size, args.work_dir)
|
175 |
+
|
176 |
+
def main(args):
|
177 |
+
gpu = 0
|
178 |
+
torch.cuda.set_device(gpu)
|
179 |
+
# build model
|
180 |
+
sam_model = sam_model_registry['vit'](args=args)
|
181 |
+
segvol_model = SegVol(
|
182 |
+
image_encoder=sam_model.image_encoder,
|
183 |
+
mask_decoder=sam_model.mask_decoder,
|
184 |
+
prompt_encoder=sam_model.prompt_encoder,
|
185 |
+
clip_ckpt=args.clip_ckpt,
|
186 |
+
roi_size=args.spatial_size,
|
187 |
+
patch_size=args.patch_size,
|
188 |
+
test_mode=args.test_mode,
|
189 |
+
).cuda()
|
190 |
+
segvol_model = torch.nn.DataParallel(segvol_model, device_ids=[gpu])
|
191 |
+
|
192 |
+
# load param
|
193 |
+
if os.path.isfile(args.resume):
|
194 |
+
## Map model to be loaded to specified single GPU
|
195 |
+
loc = 'cuda:{}'.format(gpu)
|
196 |
+
checkpoint = torch.load(args.resume, map_location=loc)
|
197 |
+
segvol_model.load_state_dict(checkpoint['model'], strict=True)
|
198 |
+
print("loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
|
199 |
+
|
200 |
+
# load demo config
|
201 |
+
with open(args.demo_config, 'r') as file:
|
202 |
+
config_dict = json.load(file)
|
203 |
+
ct_path, gt_path, categories = config_dict['demo_case']['ct_path'], config_dict['demo_case']['gt_path'], config_dict['categories']
|
204 |
+
|
205 |
+
# preprocess for data
|
206 |
+
data_item = process_ct_gt(ct_path, gt_path, categories, args.spatial_size) # keys: image, label
|
207 |
+
|
208 |
+
# seg config for prompt & zoom-in-zoom-out
|
209 |
+
args.use_zoom_in = True
|
210 |
+
args.use_text_prompt = True
|
211 |
+
args.use_box_prompt = True
|
212 |
+
args.use_point_prompt = False
|
213 |
+
args.visualize = False
|
214 |
+
|
215 |
+
inference_single_ct(args, segvol_model, data_item, categories)
|
216 |
+
|
217 |
+
if __name__ == "__main__":
|
218 |
+
args = set_parse()
|
219 |
+
main(args)
|
model/network/__pycache__/model.cpython-39.pyc
ADDED
Binary file (3.28 kB). View file
|
|
model/network/model.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextConfig
|
6 |
+
|
7 |
+
#%% set up model
|
8 |
+
class SegVol(nn.Module):
|
9 |
+
def __init__(self,
|
10 |
+
image_encoder,
|
11 |
+
mask_decoder,
|
12 |
+
prompt_encoder,
|
13 |
+
clip_ckpt,
|
14 |
+
roi_size,
|
15 |
+
patch_size,
|
16 |
+
test_mode=False,
|
17 |
+
):
|
18 |
+
super().__init__()
|
19 |
+
self.image_encoder = image_encoder
|
20 |
+
self.mask_decoder = mask_decoder
|
21 |
+
self.prompt_encoder = prompt_encoder
|
22 |
+
self.text_encoder = TextEncoder(clip_ckpt)
|
23 |
+
self.feat_shape = np.array(roi_size)/np.array(patch_size)
|
24 |
+
self.test_mode = test_mode
|
25 |
+
|
26 |
+
def forward(self, image, text=None, boxes=None, points=None, **kwargs):
|
27 |
+
bs = image.shape[0]
|
28 |
+
img_shape = (image.shape[2], image.shape[3], image.shape[4])
|
29 |
+
image_embedding, _ = self.image_encoder(image)
|
30 |
+
image_embedding = image_embedding.transpose(1, 2).view(bs, -1,
|
31 |
+
int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))
|
32 |
+
# test mode
|
33 |
+
if self.test_mode:
|
34 |
+
return self.forward_decoder(image_embedding, img_shape, text, boxes, points)
|
35 |
+
# train mode
|
36 |
+
# future release
|
37 |
+
|
38 |
+
def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
|
39 |
+
with torch.no_grad():
|
40 |
+
if boxes is not None:
|
41 |
+
if len(boxes.shape) == 2:
|
42 |
+
boxes = boxes[:, None, :] # (B, 1, 6)
|
43 |
+
if text is not None:
|
44 |
+
text_embedding = self.text_encoder(text) # (B, 768)
|
45 |
+
else:
|
46 |
+
text_embedding = None
|
47 |
+
sparse_embeddings, dense_embeddings = self.prompt_encoder(
|
48 |
+
points=points,
|
49 |
+
boxes=boxes,
|
50 |
+
masks=None,
|
51 |
+
text_embedding=text_embedding,
|
52 |
+
)
|
53 |
+
|
54 |
+
dense_pe = self.prompt_encoder.get_dense_pe()
|
55 |
+
low_res_masks, _ = self.mask_decoder(
|
56 |
+
image_embeddings=image_embedding,
|
57 |
+
text_embedding = text_embedding,
|
58 |
+
image_pe=dense_pe,
|
59 |
+
sparse_prompt_embeddings=sparse_embeddings,
|
60 |
+
dense_prompt_embeddings=dense_embeddings,
|
61 |
+
multimask_output=False,
|
62 |
+
)
|
63 |
+
logits = F.interpolate(low_res_masks, size=img_shape, mode='trilinear', align_corners=False)
|
64 |
+
return logits
|
65 |
+
|
66 |
+
class TextEncoder(nn.Module):
|
67 |
+
def __init__(self, clip_ckpt):
|
68 |
+
super().__init__()
|
69 |
+
config = CLIPTextConfig()
|
70 |
+
self.clip_text_model = CLIPTextModel(config)
|
71 |
+
self.tokenizer = AutoTokenizer.from_pretrained(clip_ckpt)
|
72 |
+
self.dim_align = nn.Linear(512, 768)
|
73 |
+
# freeze text encoder
|
74 |
+
for param in self.clip_text_model.parameters():
|
75 |
+
param.requires_grad = False
|
76 |
+
|
77 |
+
def organ2tokens(self, organ_names):
|
78 |
+
text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
|
79 |
+
tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
|
80 |
+
return tokens
|
81 |
+
|
82 |
+
def forward(self, text):
|
83 |
+
if text is None:
|
84 |
+
return None
|
85 |
+
if type(text) is str:
|
86 |
+
text = [text]
|
87 |
+
tokens = self.organ2tokens(text)
|
88 |
+
clip_outputs = self.clip_text_model(**tokens)
|
89 |
+
text_embedding = clip_outputs.pooler_output
|
90 |
+
text_embedding = self.dim_align(text_embedding)
|
91 |
+
return text_embedding
|
model/script/inference_demo.sh
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export segvol_ckpt="path/to/SegVol_v1.pth"
|
2 |
+
export work_dir="path/to/work_dir"
|
3 |
+
export demo_config_path="./config/config_demo.json"
|
4 |
+
|
5 |
+
CUDA_VISIBLE_DEVICES=0 python inference_demo.py \
|
6 |
+
--resume $segvol_ckpt \
|
7 |
+
-work_dir $work_dir \
|
8 |
+
--demo_config $demo_config_path
|
model/segment_anything_volumetric/.ipynb_checkpoints/build_sam-checkpoint.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
from functools import partial
|
7 |
+
from pathlib import Path
|
8 |
+
import urllib.request
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from .modeling import (
|
12 |
+
ImageEncoderViT,
|
13 |
+
MaskDecoder,
|
14 |
+
PromptEncoder,
|
15 |
+
Sam,
|
16 |
+
TwoWayTransformer,
|
17 |
+
)
|
18 |
+
|
19 |
+
from .modeling.image_encoder_swin import SwinTransformer
|
20 |
+
|
21 |
+
from monai.utils import ensure_tuple_rep, optional_import
|
22 |
+
|
23 |
+
def build_sam_vit_h(checkpoint=None, image_size=1024):
|
24 |
+
return _build_sam(
|
25 |
+
encoder_embed_dim=1280,
|
26 |
+
encoder_depth=32,
|
27 |
+
encoder_num_heads=16,
|
28 |
+
encoder_global_attn_indexes=[7, 15, 23, 31],
|
29 |
+
checkpoint=checkpoint,
|
30 |
+
image_size=image_size,
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
build_sam = build_sam_vit_h
|
35 |
+
|
36 |
+
|
37 |
+
def build_sam_vit_l(checkpoint=None, image_size=1024):
|
38 |
+
return _build_sam(
|
39 |
+
encoder_embed_dim=1024,
|
40 |
+
encoder_depth=24,
|
41 |
+
encoder_num_heads=16,
|
42 |
+
encoder_global_attn_indexes=[5, 11, 17, 23],
|
43 |
+
checkpoint=checkpoint,
|
44 |
+
image_size=image_size,
|
45 |
+
)
|
46 |
+
|
47 |
+
|
48 |
+
def build_sam_vit_b(checkpoint=None, image_size=1024):
|
49 |
+
return _build_sam(
|
50 |
+
encoder_embed_dim=768,
|
51 |
+
encoder_depth=12,
|
52 |
+
encoder_num_heads=12,
|
53 |
+
encoder_global_attn_indexes=[2, 5, 8, 11],
|
54 |
+
checkpoint=checkpoint,
|
55 |
+
image_size=image_size,
|
56 |
+
)
|
57 |
+
"""
|
58 |
+
Examples::
|
59 |
+
# for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
|
60 |
+
>>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
|
61 |
+
# for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
|
62 |
+
>>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
|
63 |
+
# for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
|
64 |
+
>>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
|
65 |
+
"""
|
66 |
+
|
67 |
+
def build_sam_vit_swin(checkpoint=None, image_size=96):
|
68 |
+
print('==> build_sam_vit_swin')
|
69 |
+
return _build_sam(
|
70 |
+
encoder_embed_dim=48,
|
71 |
+
encoder_depth=12,
|
72 |
+
encoder_num_heads=12,
|
73 |
+
encoder_global_attn_indexes=[2, 5, 8, 11],
|
74 |
+
checkpoint=checkpoint,
|
75 |
+
image_size=image_size,
|
76 |
+
)
|
77 |
+
|
78 |
+
sam_model_registry = {
|
79 |
+
"default": build_sam_vit_h,
|
80 |
+
"vit_h": build_sam_vit_h,
|
81 |
+
"vit_l": build_sam_vit_l,
|
82 |
+
"vit_b": build_sam_vit_b,
|
83 |
+
"swin_vit": build_sam_vit_swin,
|
84 |
+
}
|
85 |
+
|
86 |
+
|
87 |
+
def _build_sam(
|
88 |
+
encoder_embed_dim,
|
89 |
+
encoder_depth,
|
90 |
+
encoder_num_heads,
|
91 |
+
encoder_global_attn_indexes,
|
92 |
+
checkpoint=None,
|
93 |
+
image_size=None,
|
94 |
+
spatial_dims=3,
|
95 |
+
):
|
96 |
+
prompt_embed_dim = 768
|
97 |
+
patch_size = ensure_tuple_rep(2, spatial_dims)
|
98 |
+
window_size = ensure_tuple_rep(7, spatial_dims)
|
99 |
+
image_embedding_size = [size // 32 for size in image_size]
|
100 |
+
sam = Sam(
|
101 |
+
image_encoder=SwinTransformer(
|
102 |
+
in_chans=1,
|
103 |
+
embed_dim=encoder_embed_dim,
|
104 |
+
window_size=window_size,
|
105 |
+
patch_size=patch_size,
|
106 |
+
depths=(2, 2, 6, 2), #(2, 2, 6, 2),
|
107 |
+
num_heads=(3, 6, 12, 24),
|
108 |
+
mlp_ratio=4.0,
|
109 |
+
qkv_bias=True,
|
110 |
+
spatial_dims=spatial_dims,
|
111 |
+
),
|
112 |
+
prompt_encoder=PromptEncoder(
|
113 |
+
embed_dim=prompt_embed_dim,
|
114 |
+
image_embedding_size=image_embedding_size,
|
115 |
+
input_image_size=image_size,
|
116 |
+
mask_in_chans=16,
|
117 |
+
),
|
118 |
+
mask_decoder=MaskDecoder(
|
119 |
+
num_multimask_outputs=3,
|
120 |
+
transformer=TwoWayTransformer(
|
121 |
+
depth=2,
|
122 |
+
embedding_dim=prompt_embed_dim,
|
123 |
+
mlp_dim=2048,
|
124 |
+
num_heads=8,
|
125 |
+
),
|
126 |
+
transformer_dim=prompt_embed_dim,
|
127 |
+
iou_head_depth=3,
|
128 |
+
iou_head_hidden_dim=256,
|
129 |
+
),
|
130 |
+
pixel_mean=[123.675, 116.28, 103.53],
|
131 |
+
pixel_std=[58.395, 57.12, 57.375],
|
132 |
+
)
|
133 |
+
sam.eval()
|
134 |
+
if checkpoint is not None:
|
135 |
+
checkpoint = Path(checkpoint)
|
136 |
+
if checkpoint.name == "sam_vit_b_01ec64.pth" and not checkpoint.exists():
|
137 |
+
cmd = input("Download sam_vit_b_01ec64.pth from facebook AI? [y]/n: ")
|
138 |
+
if len(cmd) == 0 or cmd.lower() == 'y':
|
139 |
+
checkpoint.parent.mkdir(parents=True, exist_ok=True)
|
140 |
+
print("Downloading SAM ViT-B checkpoint...")
|
141 |
+
urllib.request.urlretrieve(
|
142 |
+
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
|
143 |
+
checkpoint,
|
144 |
+
)
|
145 |
+
print(checkpoint.name, " is downloaded!")
|
146 |
+
elif checkpoint.name == "sam_vit_h_4b8939.pth" and not checkpoint.exists():
|
147 |
+
cmd = input("Download sam_vit_h_4b8939.pth from facebook AI? [y]/n: ")
|
148 |
+
if len(cmd) == 0 or cmd.lower() == 'y':
|
149 |
+
checkpoint.parent.mkdir(parents=True, exist_ok=True)
|
150 |
+
print("Downloading SAM ViT-H checkpoint...")
|
151 |
+
urllib.request.urlretrieve(
|
152 |
+
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
|
153 |
+
checkpoint,
|
154 |
+
)
|
155 |
+
print(checkpoint.name, " is downloaded!")
|
156 |
+
elif checkpoint.name == "sam_vit_l_0b3195.pth" and not checkpoint.exists():
|
157 |
+
cmd = input("Download sam_vit_l_0b3195.pth from facebook AI? [y]/n: ")
|
158 |
+
if len(cmd) == 0 or cmd.lower() == 'y':
|
159 |
+
checkpoint.parent.mkdir(parents=True, exist_ok=True)
|
160 |
+
print("Downloading SAM ViT-L checkpoint...")
|
161 |
+
urllib.request.urlretrieve(
|
162 |
+
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
|
163 |
+
checkpoint,
|
164 |
+
)
|
165 |
+
print(checkpoint.name, " is downloaded!")
|
166 |
+
|
167 |
+
|
168 |
+
if checkpoint is not None:
|
169 |
+
with open(checkpoint, "rb") as f:
|
170 |
+
state_dict = torch.load(f)
|
171 |
+
sam.load_state_dict(state_dict)
|
172 |
+
return sam
|
model/segment_anything_volumetric/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from .build_sam import (
|
8 |
+
build_sam_vit_3d,
|
9 |
+
sam_model_registry,
|
10 |
+
)
|
11 |
+
from .predictor import SamPredictor
|
12 |
+
from .automatic_mask_generator import SamAutomaticMaskGenerator
|
model/segment_anything_volumetric/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (407 Bytes). View file
|
|
model/segment_anything_volumetric/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (377 Bytes). View file
|
|
model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-310.pyc
ADDED
Binary file (11.4 kB). View file
|
|
model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-39.pyc
ADDED
Binary file (11.4 kB). View file
|
|
model/segment_anything_volumetric/__pycache__/build_sam.cpython-310.pyc
ADDED
Binary file (3.3 kB). View file
|
|
model/segment_anything_volumetric/__pycache__/build_sam.cpython-39.pyc
ADDED
Binary file (2.62 kB). View file
|
|
model/segment_anything_volumetric/__pycache__/predictor.cpython-310.pyc
ADDED
Binary file (9.96 kB). View file
|
|
model/segment_anything_volumetric/__pycache__/predictor.cpython-39.pyc
ADDED
Binary file (9.98 kB). View file
|
|
model/segment_anything_volumetric/automatic_mask_generator.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
10 |
+
|
11 |
+
from typing import Any, Dict, List, Optional, Tuple
|
12 |
+
|
13 |
+
from .modeling import Sam
|
14 |
+
from .predictor import SamPredictor
|
15 |
+
from .utils.amg import (
|
16 |
+
MaskData,
|
17 |
+
area_from_rle,
|
18 |
+
batch_iterator,
|
19 |
+
batched_mask_to_box,
|
20 |
+
box_xyxy_to_xywh,
|
21 |
+
build_all_layer_point_grids,
|
22 |
+
calculate_stability_score,
|
23 |
+
coco_encode_rle,
|
24 |
+
generate_crop_boxes,
|
25 |
+
is_box_near_crop_edge,
|
26 |
+
mask_to_rle_pytorch,
|
27 |
+
remove_small_regions,
|
28 |
+
rle_to_mask,
|
29 |
+
uncrop_boxes_xyxy,
|
30 |
+
uncrop_masks,
|
31 |
+
uncrop_points,
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
class SamAutomaticMaskGenerator:
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
model: Sam,
|
39 |
+
points_per_side: Optional[int] = 32,
|
40 |
+
points_per_batch: int = 64,
|
41 |
+
pred_iou_thresh: float = 0.88,
|
42 |
+
stability_score_thresh: float = 0.95,
|
43 |
+
stability_score_offset: float = 1.0,
|
44 |
+
box_nms_thresh: float = 0.7,
|
45 |
+
crop_n_layers: int = 0,
|
46 |
+
crop_nms_thresh: float = 0.7,
|
47 |
+
crop_overlap_ratio: float = 512 / 1500,
|
48 |
+
crop_n_points_downscale_factor: int = 1,
|
49 |
+
point_grids: Optional[List[np.ndarray]] = None,
|
50 |
+
min_mask_region_area: int = 0,
|
51 |
+
output_mode: str = "binary_mask",
|
52 |
+
) -> None:
|
53 |
+
"""
|
54 |
+
Using a SAM model, generates masks for the entire image.
|
55 |
+
Generates a grid of point prompts over the image, then filters
|
56 |
+
low quality and duplicate masks. The default settings are chosen
|
57 |
+
for SAM with a ViT-H backbone.
|
58 |
+
|
59 |
+
Arguments:
|
60 |
+
model (Sam): The SAM model to use for mask prediction.
|
61 |
+
points_per_side (int or None): The number of points to be sampled
|
62 |
+
along one side of the image. The total number of points is
|
63 |
+
points_per_side**2. If None, 'point_grids' must provide explicit
|
64 |
+
point sampling.
|
65 |
+
points_per_batch (int): Sets the number of points run simultaneously
|
66 |
+
by the model. Higher numbers may be faster but use more GPU memory.
|
67 |
+
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
68 |
+
model's predicted mask quality.
|
69 |
+
stability_score_thresh (float): A filtering threshold in [0,1], using
|
70 |
+
the stability of the mask under changes to the cutoff used to binarize
|
71 |
+
the model's mask predictions.
|
72 |
+
stability_score_offset (float): The amount to shift the cutoff when
|
73 |
+
calculated the stability score.
|
74 |
+
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
75 |
+
suppression to filter duplicate masks.
|
76 |
+
crop_n_layers (int): If >0, mask prediction will be run again on
|
77 |
+
crops of the image. Sets the number of layers to run, where each
|
78 |
+
layer has 2**i_layer number of image crops.
|
79 |
+
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
80 |
+
suppression to filter duplicate masks between different crops.
|
81 |
+
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
82 |
+
In the first crop layer, crops will overlap by this fraction of
|
83 |
+
the image length. Later layers with more crops scale down this overlap.
|
84 |
+
crop_n_points_downscale_factor (int): The number of points-per-side
|
85 |
+
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
86 |
+
point_grids (list(np.ndarray) or None): A list over explicit grids
|
87 |
+
of points used for sampling, normalized to [0,1]. The nth grid in the
|
88 |
+
list is used in the nth crop layer. Exclusive with points_per_side.
|
89 |
+
min_mask_region_area (int): If >0, postprocessing will be applied
|
90 |
+
to remove disconnected regions and holes in masks with area smaller
|
91 |
+
than min_mask_region_area. Requires opencv.
|
92 |
+
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
93 |
+
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
94 |
+
For large resolutions, 'binary_mask' may consume large amounts of
|
95 |
+
memory.
|
96 |
+
"""
|
97 |
+
|
98 |
+
assert (points_per_side is None) != (
|
99 |
+
point_grids is None
|
100 |
+
), "Exactly one of points_per_side or point_grid must be provided."
|
101 |
+
if points_per_side is not None:
|
102 |
+
self.point_grids = build_all_layer_point_grids(
|
103 |
+
points_per_side,
|
104 |
+
crop_n_layers,
|
105 |
+
crop_n_points_downscale_factor,
|
106 |
+
)
|
107 |
+
elif point_grids is not None:
|
108 |
+
self.point_grids = point_grids
|
109 |
+
else:
|
110 |
+
raise ValueError("Can't have both points_per_side and point_grid be None.")
|
111 |
+
|
112 |
+
assert output_mode in [
|
113 |
+
"binary_mask",
|
114 |
+
"uncompressed_rle",
|
115 |
+
"coco_rle",
|
116 |
+
], f"Unknown output_mode {output_mode}."
|
117 |
+
if output_mode == "coco_rle":
|
118 |
+
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
|
119 |
+
|
120 |
+
if min_mask_region_area > 0:
|
121 |
+
import cv2 # type: ignore # noqa: F401
|
122 |
+
|
123 |
+
self.predictor = SamPredictor(model)
|
124 |
+
self.points_per_batch = points_per_batch
|
125 |
+
self.pred_iou_thresh = pred_iou_thresh
|
126 |
+
self.stability_score_thresh = stability_score_thresh
|
127 |
+
self.stability_score_offset = stability_score_offset
|
128 |
+
self.box_nms_thresh = box_nms_thresh
|
129 |
+
self.crop_n_layers = crop_n_layers
|
130 |
+
self.crop_nms_thresh = crop_nms_thresh
|
131 |
+
self.crop_overlap_ratio = crop_overlap_ratio
|
132 |
+
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
133 |
+
self.min_mask_region_area = min_mask_region_area
|
134 |
+
self.output_mode = output_mode
|
135 |
+
|
136 |
+
@torch.no_grad()
|
137 |
+
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
138 |
+
"""
|
139 |
+
Generates masks for the given image.
|
140 |
+
|
141 |
+
Arguments:
|
142 |
+
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
list(dict(str, any)): A list over records for masks. Each record is
|
146 |
+
a dict containing the following keys:
|
147 |
+
segmentation (dict(str, any) or np.ndarray): The mask. If
|
148 |
+
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
149 |
+
is a dictionary containing the RLE.
|
150 |
+
bbox (list(float)): The box around the mask, in XYWH format.
|
151 |
+
area (int): The area in pixels of the mask.
|
152 |
+
predicted_iou (float): The model's own prediction of the mask's
|
153 |
+
quality. This is filtered by the pred_iou_thresh parameter.
|
154 |
+
point_coords (list(list(float))): The point coordinates input
|
155 |
+
to the model to generate this mask.
|
156 |
+
stability_score (float): A measure of the mask's quality. This
|
157 |
+
is filtered on using the stability_score_thresh parameter.
|
158 |
+
crop_box (list(float)): The crop of the image used to generate
|
159 |
+
the mask, given in XYWH format.
|
160 |
+
"""
|
161 |
+
|
162 |
+
# Generate masks
|
163 |
+
mask_data = self._generate_masks(image)
|
164 |
+
|
165 |
+
# Filter small disconnected regions and holes in masks
|
166 |
+
if self.min_mask_region_area > 0:
|
167 |
+
mask_data = self.postprocess_small_regions(
|
168 |
+
mask_data,
|
169 |
+
self.min_mask_region_area,
|
170 |
+
max(self.box_nms_thresh, self.crop_nms_thresh),
|
171 |
+
)
|
172 |
+
|
173 |
+
# Encode masks
|
174 |
+
if self.output_mode == "coco_rle":
|
175 |
+
mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
|
176 |
+
elif self.output_mode == "binary_mask":
|
177 |
+
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
178 |
+
else:
|
179 |
+
mask_data["segmentations"] = mask_data["rles"]
|
180 |
+
|
181 |
+
# Write mask records
|
182 |
+
curr_anns = []
|
183 |
+
for idx in range(len(mask_data["segmentations"])):
|
184 |
+
ann = {
|
185 |
+
"segmentation": mask_data["segmentations"][idx],
|
186 |
+
"area": area_from_rle(mask_data["rles"][idx]),
|
187 |
+
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
188 |
+
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
189 |
+
"point_coords": [mask_data["points"][idx].tolist()],
|
190 |
+
"stability_score": mask_data["stability_score"][idx].item(),
|
191 |
+
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
192 |
+
}
|
193 |
+
curr_anns.append(ann)
|
194 |
+
|
195 |
+
return curr_anns
|
196 |
+
|
197 |
+
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
198 |
+
orig_size = image.shape[:2]
|
199 |
+
crop_boxes, layer_idxs = generate_crop_boxes(
|
200 |
+
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
201 |
+
)
|
202 |
+
|
203 |
+
# Iterate over image crops
|
204 |
+
data = MaskData()
|
205 |
+
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
206 |
+
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
|
207 |
+
data.cat(crop_data)
|
208 |
+
|
209 |
+
# Remove duplicate masks between crops
|
210 |
+
if len(crop_boxes) > 1:
|
211 |
+
# Prefer masks from smaller crops
|
212 |
+
scores = 1 / box_area(data["crop_boxes"])
|
213 |
+
scores = scores.to(data["boxes"].device)
|
214 |
+
keep_by_nms = batched_nms(
|
215 |
+
data["boxes"].float(),
|
216 |
+
scores,
|
217 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
218 |
+
iou_threshold=self.crop_nms_thresh,
|
219 |
+
)
|
220 |
+
data.filter(keep_by_nms)
|
221 |
+
|
222 |
+
data.to_numpy()
|
223 |
+
return data
|
224 |
+
|
225 |
+
def _process_crop(
|
226 |
+
self,
|
227 |
+
image: np.ndarray,
|
228 |
+
crop_box: List[int],
|
229 |
+
crop_layer_idx: int,
|
230 |
+
orig_size: Tuple[int, ...],
|
231 |
+
) -> MaskData:
|
232 |
+
# Crop the image and calculate embeddings
|
233 |
+
x0, y0, x1, y1 = crop_box
|
234 |
+
cropped_im = image[y0:y1, x0:x1, :]
|
235 |
+
cropped_im_size = cropped_im.shape[:2]
|
236 |
+
self.predictor.set_image(cropped_im)
|
237 |
+
|
238 |
+
# Get points for this crop
|
239 |
+
points_scale = np.array(cropped_im_size)[None, ::-1]
|
240 |
+
points_for_image = self.point_grids[crop_layer_idx] * points_scale
|
241 |
+
|
242 |
+
# Generate masks for this crop in batches
|
243 |
+
data = MaskData()
|
244 |
+
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
245 |
+
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
|
246 |
+
data.cat(batch_data)
|
247 |
+
del batch_data
|
248 |
+
self.predictor.reset_image()
|
249 |
+
|
250 |
+
# Remove duplicates within this crop.
|
251 |
+
keep_by_nms = batched_nms(
|
252 |
+
data["boxes"].float(),
|
253 |
+
data["iou_preds"],
|
254 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
255 |
+
iou_threshold=self.box_nms_thresh,
|
256 |
+
)
|
257 |
+
data.filter(keep_by_nms)
|
258 |
+
|
259 |
+
# Return to the original image frame
|
260 |
+
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
|
261 |
+
data["points"] = uncrop_points(data["points"], crop_box)
|
262 |
+
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
|
263 |
+
|
264 |
+
return data
|
265 |
+
|
266 |
+
def _process_batch(
|
267 |
+
self,
|
268 |
+
points: np.ndarray,
|
269 |
+
im_size: Tuple[int, ...],
|
270 |
+
crop_box: List[int],
|
271 |
+
orig_size: Tuple[int, ...],
|
272 |
+
) -> MaskData:
|
273 |
+
orig_h, orig_w = orig_size
|
274 |
+
|
275 |
+
# Run model on this batch
|
276 |
+
transformed_points = self.predictor.transform.apply_coords(points, im_size)
|
277 |
+
in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
|
278 |
+
in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
|
279 |
+
masks, iou_preds, _ = self.predictor.predict_torch(
|
280 |
+
in_points[:, None, :],
|
281 |
+
in_labels[:, None],
|
282 |
+
multimask_output=True,
|
283 |
+
return_logits=True,
|
284 |
+
)
|
285 |
+
|
286 |
+
# Serialize predictions and store in MaskData
|
287 |
+
data = MaskData(
|
288 |
+
masks=masks.flatten(0, 1),
|
289 |
+
iou_preds=iou_preds.flatten(0, 1),
|
290 |
+
points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
|
291 |
+
)
|
292 |
+
del masks
|
293 |
+
|
294 |
+
# Filter by predicted IoU
|
295 |
+
if self.pred_iou_thresh > 0.0:
|
296 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
297 |
+
data.filter(keep_mask)
|
298 |
+
|
299 |
+
# Calculate stability score
|
300 |
+
data["stability_score"] = calculate_stability_score(
|
301 |
+
data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
|
302 |
+
)
|
303 |
+
if self.stability_score_thresh > 0.0:
|
304 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
305 |
+
data.filter(keep_mask)
|
306 |
+
|
307 |
+
# Threshold masks and calculate boxes
|
308 |
+
data["masks"] = data["masks"] > self.predictor.model.mask_threshold
|
309 |
+
data["boxes"] = batched_mask_to_box(data["masks"])
|
310 |
+
|
311 |
+
# Filter boxes that touch crop boundaries
|
312 |
+
keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
|
313 |
+
if not torch.all(keep_mask):
|
314 |
+
data.filter(keep_mask)
|
315 |
+
|
316 |
+
# Compress to RLE
|
317 |
+
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
|
318 |
+
data["rles"] = mask_to_rle_pytorch(data["masks"])
|
319 |
+
del data["masks"]
|
320 |
+
|
321 |
+
return data
|
322 |
+
|
323 |
+
@staticmethod
|
324 |
+
def postprocess_small_regions(
|
325 |
+
mask_data: MaskData, min_area: int, nms_thresh: float
|
326 |
+
) -> MaskData:
|
327 |
+
"""
|
328 |
+
Removes small disconnected regions and holes in masks, then reruns
|
329 |
+
box NMS to remove any new duplicates.
|
330 |
+
|
331 |
+
Edits mask_data in place.
|
332 |
+
|
333 |
+
Requires open-cv as a dependency.
|
334 |
+
"""
|
335 |
+
if len(mask_data["rles"]) == 0:
|
336 |
+
return mask_data
|
337 |
+
|
338 |
+
# Filter small disconnected regions and holes
|
339 |
+
new_masks = []
|
340 |
+
scores = []
|
341 |
+
for rle in mask_data["rles"]:
|
342 |
+
mask = rle_to_mask(rle)
|
343 |
+
|
344 |
+
mask, changed = remove_small_regions(mask, min_area, mode="holes")
|
345 |
+
unchanged = not changed
|
346 |
+
mask, changed = remove_small_regions(mask, min_area, mode="islands")
|
347 |
+
unchanged = unchanged and not changed
|
348 |
+
|
349 |
+
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
350 |
+
# Give score=0 to changed masks and score=1 to unchanged masks
|
351 |
+
# so NMS will prefer ones that didn't need postprocessing
|
352 |
+
scores.append(float(unchanged))
|
353 |
+
|
354 |
+
# Recalculate boxes and remove any new duplicates
|
355 |
+
masks = torch.cat(new_masks, dim=0)
|
356 |
+
boxes = batched_mask_to_box(masks)
|
357 |
+
keep_by_nms = batched_nms(
|
358 |
+
boxes.float(),
|
359 |
+
torch.as_tensor(scores),
|
360 |
+
torch.zeros_like(boxes[:, 0]), # categories
|
361 |
+
iou_threshold=nms_thresh,
|
362 |
+
)
|
363 |
+
|
364 |
+
# Only recalculate RLEs for masks that have changed
|
365 |
+
for i_mask in keep_by_nms:
|
366 |
+
if scores[i_mask] == 0.0:
|
367 |
+
mask_torch = masks[i_mask].unsqueeze(0)
|
368 |
+
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
|
369 |
+
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
|
370 |
+
mask_data.filter(keep_by_nms)
|
371 |
+
|
372 |
+
return mask_data
|
model/segment_anything_volumetric/build_sam.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
from functools import partial
|
7 |
+
from pathlib import Path
|
8 |
+
import urllib.request
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from .modeling import (
|
12 |
+
ImageEncoderViT,
|
13 |
+
MaskDecoder,
|
14 |
+
PromptEncoder,
|
15 |
+
Sam,
|
16 |
+
TwoWayTransformer,
|
17 |
+
)
|
18 |
+
import numpy as np
|
19 |
+
from .modeling.image_encoder_swin import SwinTransformer
|
20 |
+
from monai.networks.nets import ViT
|
21 |
+
from monai.networks.nets.swin_unetr import SwinTransformer as SwinViT
|
22 |
+
|
23 |
+
from monai.utils import ensure_tuple_rep, optional_import
|
24 |
+
|
25 |
+
|
26 |
+
"""
|
27 |
+
Examples::
|
28 |
+
# for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
|
29 |
+
>>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
|
30 |
+
# for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
|
31 |
+
>>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
|
32 |
+
# for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
|
33 |
+
>>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
|
34 |
+
"""
|
35 |
+
|
36 |
+
def build_sam_vit_3d(checkpoint=None):
|
37 |
+
print('build_sam_vit_3d...')
|
38 |
+
return _build_sam(
|
39 |
+
image_encoder_type='vit',
|
40 |
+
embed_dim = 768,
|
41 |
+
patch_size=[4,16,16],
|
42 |
+
checkpoint=checkpoint,
|
43 |
+
image_size=[32,256,256],
|
44 |
+
)
|
45 |
+
|
46 |
+
sam_model_registry = {
|
47 |
+
"vit": build_sam_vit_3d,
|
48 |
+
}
|
49 |
+
|
50 |
+
|
51 |
+
def _build_sam(
|
52 |
+
image_encoder_type,
|
53 |
+
embed_dim,
|
54 |
+
patch_size,
|
55 |
+
checkpoint,
|
56 |
+
image_size,
|
57 |
+
):
|
58 |
+
mlp_dim = 3072
|
59 |
+
num_layers = 12
|
60 |
+
num_heads = 12
|
61 |
+
pos_embed = 'perceptron'
|
62 |
+
dropout_rate = 0.0
|
63 |
+
|
64 |
+
image_encoder=ViT(
|
65 |
+
in_channels=1,
|
66 |
+
img_size=image_size,
|
67 |
+
patch_size=patch_size,
|
68 |
+
hidden_size=embed_dim,
|
69 |
+
mlp_dim=mlp_dim,
|
70 |
+
num_layers=num_layers,
|
71 |
+
num_heads=num_heads,
|
72 |
+
pos_embed=pos_embed,
|
73 |
+
classification=False,
|
74 |
+
dropout_rate=dropout_rate,
|
75 |
+
)
|
76 |
+
image_embedding_size = [int(item) for item in (np.array(image_size) / np.array(patch_size))]
|
77 |
+
|
78 |
+
if checkpoint is not None:
|
79 |
+
with open(checkpoint, "rb") as f:
|
80 |
+
state_dict = torch.load(f, map_location='cpu')['state_dict']
|
81 |
+
encoder_dict = {k.replace('model.encoder.', ''): v for k, v in state_dict.items() if 'model.encoder.' in k}
|
82 |
+
image_encoder.load_state_dict(encoder_dict)
|
83 |
+
print(f'===> image_encoder.load_param: {checkpoint}')
|
84 |
+
sam = Sam(
|
85 |
+
image_encoder=image_encoder,
|
86 |
+
prompt_encoder=PromptEncoder(
|
87 |
+
embed_dim=embed_dim,
|
88 |
+
image_embedding_size=image_embedding_size,
|
89 |
+
input_image_size=image_size,
|
90 |
+
mask_in_chans=16,
|
91 |
+
),
|
92 |
+
mask_decoder=MaskDecoder(
|
93 |
+
image_encoder_type=image_encoder_type,
|
94 |
+
num_multimask_outputs=3,
|
95 |
+
transformer=TwoWayTransformer(
|
96 |
+
depth=2,
|
97 |
+
embedding_dim=embed_dim,
|
98 |
+
mlp_dim=2048,
|
99 |
+
num_heads=8,
|
100 |
+
),
|
101 |
+
transformer_dim=embed_dim,
|
102 |
+
iou_head_depth=3,
|
103 |
+
iou_head_hidden_dim=256,
|
104 |
+
image_size=np.array(image_size),
|
105 |
+
patch_size=np.array(patch_size),
|
106 |
+
),
|
107 |
+
pixel_mean=[123.675, 116.28, 103.53],
|
108 |
+
pixel_std=[58.395, 57.12, 57.375],
|
109 |
+
)
|
110 |
+
sam.eval()
|
111 |
+
return sam
|
model/segment_anything_volumetric/modeling/.ipynb_checkpoints/image_encoder_swin-checkpoint.py
ADDED
@@ -0,0 +1,709 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Sequence, Tuple, Type, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.utils.checkpoint as checkpoint
|
8 |
+
from torch.nn import LayerNorm
|
9 |
+
|
10 |
+
from monai.networks.blocks import MLPBlock as Mlp
|
11 |
+
from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
|
12 |
+
from monai.networks.layers import DropPath, trunc_normal_
|
13 |
+
from monai.utils import ensure_tuple_rep, optional_import
|
14 |
+
|
15 |
+
rearrange, _ = optional_import("einops", name="rearrange")
|
16 |
+
|
17 |
+
def window_partition(x, window_size):
|
18 |
+
"""window partition operation based on: "Liu et al.,
|
19 |
+
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
20 |
+
<https://arxiv.org/abs/2103.14030>"
|
21 |
+
https://github.com/microsoft/Swin-Transformer
|
22 |
+
Args:
|
23 |
+
x: input tensor.
|
24 |
+
window_size: local window size.
|
25 |
+
"""
|
26 |
+
x_shape = x.size()
|
27 |
+
if len(x_shape) == 5:
|
28 |
+
b, d, h, w, c = x_shape
|
29 |
+
x = x.view(
|
30 |
+
b,
|
31 |
+
d // window_size[0],
|
32 |
+
window_size[0],
|
33 |
+
h // window_size[1],
|
34 |
+
window_size[1],
|
35 |
+
w // window_size[2],
|
36 |
+
window_size[2],
|
37 |
+
c,
|
38 |
+
)
|
39 |
+
windows = (
|
40 |
+
x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c)
|
41 |
+
)
|
42 |
+
elif len(x_shape) == 4:
|
43 |
+
b, h, w, c = x.shape
|
44 |
+
x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c)
|
45 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c)
|
46 |
+
return windows
|
47 |
+
|
48 |
+
|
49 |
+
def window_reverse(windows, window_size, dims):
|
50 |
+
"""window reverse operation based on: "Liu et al.,
|
51 |
+
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
52 |
+
<https://arxiv.org/abs/2103.14030>"
|
53 |
+
https://github.com/microsoft/Swin-Transformer
|
54 |
+
Args:
|
55 |
+
windows: windows tensor.
|
56 |
+
window_size: local window size.
|
57 |
+
dims: dimension values.
|
58 |
+
"""
|
59 |
+
if len(dims) == 4:
|
60 |
+
b, d, h, w = dims
|
61 |
+
x = windows.view(
|
62 |
+
b,
|
63 |
+
d // window_size[0],
|
64 |
+
h // window_size[1],
|
65 |
+
w // window_size[2],
|
66 |
+
window_size[0],
|
67 |
+
window_size[1],
|
68 |
+
window_size[2],
|
69 |
+
-1,
|
70 |
+
)
|
71 |
+
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1)
|
72 |
+
|
73 |
+
elif len(dims) == 3:
|
74 |
+
b, h, w = dims
|
75 |
+
x = windows.view(b, h // window_size[0], w // window_size[0], window_size[0], window_size[1], -1)
|
76 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
|
77 |
+
return x
|
78 |
+
|
79 |
+
|
80 |
+
def get_window_size(x_size, window_size, shift_size=None):
|
81 |
+
"""Computing window size based on: "Liu et al.,
|
82 |
+
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
83 |
+
<https://arxiv.org/abs/2103.14030>"
|
84 |
+
https://github.com/microsoft/Swin-Transformer
|
85 |
+
Args:
|
86 |
+
x_size: input size.
|
87 |
+
window_size: local window size.
|
88 |
+
shift_size: window shifting size.
|
89 |
+
"""
|
90 |
+
|
91 |
+
use_window_size = list(window_size)
|
92 |
+
if shift_size is not None:
|
93 |
+
use_shift_size = list(shift_size)
|
94 |
+
for i in range(len(x_size)):
|
95 |
+
if x_size[i] <= window_size[i]:
|
96 |
+
use_window_size[i] = x_size[i]
|
97 |
+
if shift_size is not None:
|
98 |
+
use_shift_size[i] = 0
|
99 |
+
|
100 |
+
if shift_size is None:
|
101 |
+
return tuple(use_window_size)
|
102 |
+
else:
|
103 |
+
return tuple(use_window_size), tuple(use_shift_size)
|
104 |
+
|
105 |
+
|
106 |
+
class WindowAttention(nn.Module):
|
107 |
+
"""
|
108 |
+
Window based multi-head self attention module with relative position bias based on: "Liu et al.,
|
109 |
+
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
110 |
+
<https://arxiv.org/abs/2103.14030>"
|
111 |
+
https://github.com/microsoft/Swin-Transformer
|
112 |
+
"""
|
113 |
+
|
114 |
+
def __init__(
|
115 |
+
self,
|
116 |
+
dim: int,
|
117 |
+
num_heads: int,
|
118 |
+
window_size: Sequence[int],
|
119 |
+
qkv_bias: bool = False,
|
120 |
+
attn_drop: float = 0.0,
|
121 |
+
proj_drop: float = 0.0,
|
122 |
+
) -> None:
|
123 |
+
"""
|
124 |
+
Args:
|
125 |
+
dim: number of feature channels.
|
126 |
+
num_heads: number of attention heads.
|
127 |
+
window_size: local window size.
|
128 |
+
qkv_bias: add a learnable bias to query, key, value.
|
129 |
+
attn_drop: attention dropout rate.
|
130 |
+
proj_drop: dropout rate of output.
|
131 |
+
"""
|
132 |
+
|
133 |
+
super().__init__()
|
134 |
+
self.dim = dim
|
135 |
+
self.window_size = window_size
|
136 |
+
self.num_heads = num_heads
|
137 |
+
head_dim = dim // num_heads
|
138 |
+
self.scale = head_dim**-0.5
|
139 |
+
mesh_args = torch.meshgrid.__kwdefaults__
|
140 |
+
|
141 |
+
if len(self.window_size) == 3:
|
142 |
+
self.relative_position_bias_table = nn.Parameter(
|
143 |
+
torch.zeros(
|
144 |
+
(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1),
|
145 |
+
num_heads,
|
146 |
+
)
|
147 |
+
)
|
148 |
+
coords_d = torch.arange(self.window_size[0])
|
149 |
+
coords_h = torch.arange(self.window_size[1])
|
150 |
+
coords_w = torch.arange(self.window_size[2])
|
151 |
+
if mesh_args is not None:
|
152 |
+
coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij"))
|
153 |
+
else:
|
154 |
+
coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))
|
155 |
+
coords_flatten = torch.flatten(coords, 1)
|
156 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
157 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
158 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1
|
159 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
160 |
+
relative_coords[:, :, 2] += self.window_size[2] - 1
|
161 |
+
relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
|
162 |
+
relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
|
163 |
+
elif len(self.window_size) == 2:
|
164 |
+
self.relative_position_bias_table = nn.Parameter(
|
165 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
|
166 |
+
)
|
167 |
+
coords_h = torch.arange(self.window_size[0])
|
168 |
+
coords_w = torch.arange(self.window_size[1])
|
169 |
+
if mesh_args is not None:
|
170 |
+
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
|
171 |
+
else:
|
172 |
+
coords = torch.stack(torch.meshgrid(coords_h, coords_w))
|
173 |
+
coords_flatten = torch.flatten(coords, 1)
|
174 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
175 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
176 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1
|
177 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
178 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
179 |
+
|
180 |
+
relative_position_index = relative_coords.sum(-1)
|
181 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
182 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
183 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
184 |
+
self.proj = nn.Linear(dim, dim)
|
185 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
186 |
+
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
187 |
+
self.softmax = nn.Softmax(dim=-1)
|
188 |
+
|
189 |
+
def forward(self, x, mask):
|
190 |
+
b, n, c = x.shape
|
191 |
+
qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
|
192 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
193 |
+
q = q * self.scale
|
194 |
+
attn = q @ k.transpose(-2, -1)
|
195 |
+
relative_position_bias = self.relative_position_bias_table[
|
196 |
+
self.relative_position_index.clone()[:n, :n].reshape(-1)
|
197 |
+
].reshape(n, n, -1)
|
198 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
|
199 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
200 |
+
if mask is not None:
|
201 |
+
nw = mask.shape[0]
|
202 |
+
attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
|
203 |
+
attn = attn.view(-1, self.num_heads, n, n)
|
204 |
+
attn = self.softmax(attn)
|
205 |
+
else:
|
206 |
+
attn = self.softmax(attn)
|
207 |
+
|
208 |
+
attn = self.attn_drop(attn)
|
209 |
+
x = (attn @ v).transpose(1, 2).reshape(b, n, c)
|
210 |
+
x = self.proj(x)
|
211 |
+
x = self.proj_drop(x)
|
212 |
+
return x
|
213 |
+
|
214 |
+
|
215 |
+
class SwinTransformerBlock(nn.Module):
|
216 |
+
"""
|
217 |
+
Swin Transformer block based on: "Liu et al.,
|
218 |
+
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
219 |
+
<https://arxiv.org/abs/2103.14030>"
|
220 |
+
https://github.com/microsoft/Swin-Transformer
|
221 |
+
"""
|
222 |
+
|
223 |
+
def __init__(
|
224 |
+
self,
|
225 |
+
dim: int,
|
226 |
+
num_heads: int,
|
227 |
+
window_size: Sequence[int],
|
228 |
+
shift_size: Sequence[int],
|
229 |
+
mlp_ratio: float = 4.0,
|
230 |
+
qkv_bias: bool = True,
|
231 |
+
drop: float = 0.0,
|
232 |
+
attn_drop: float = 0.0,
|
233 |
+
drop_path: float = 0.0,
|
234 |
+
act_layer: str = "GELU",
|
235 |
+
norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
|
236 |
+
use_checkpoint: bool = False,
|
237 |
+
) -> None:
|
238 |
+
"""
|
239 |
+
Args:
|
240 |
+
dim: number of feature channels.
|
241 |
+
num_heads: number of attention heads.
|
242 |
+
window_size: local window size.
|
243 |
+
shift_size: window shift size.
|
244 |
+
mlp_ratio: ratio of mlp hidden dim to embedding dim.
|
245 |
+
qkv_bias: add a learnable bias to query, key, value.
|
246 |
+
drop: dropout rate.
|
247 |
+
attn_drop: attention dropout rate.
|
248 |
+
drop_path: stochastic depth rate.
|
249 |
+
act_layer: activation layer.
|
250 |
+
norm_layer: normalization layer.
|
251 |
+
use_checkpoint: use gradient checkpointing for reduced memory usage.
|
252 |
+
"""
|
253 |
+
|
254 |
+
super().__init__()
|
255 |
+
self.dim = dim
|
256 |
+
self.num_heads = num_heads
|
257 |
+
self.window_size = window_size
|
258 |
+
self.shift_size = shift_size
|
259 |
+
self.mlp_ratio = mlp_ratio
|
260 |
+
self.use_checkpoint = use_checkpoint
|
261 |
+
self.norm1 = norm_layer(dim)
|
262 |
+
self.attn = WindowAttention(
|
263 |
+
dim,
|
264 |
+
window_size=self.window_size,
|
265 |
+
num_heads=num_heads,
|
266 |
+
qkv_bias=qkv_bias,
|
267 |
+
attn_drop=attn_drop,
|
268 |
+
proj_drop=drop,
|
269 |
+
)
|
270 |
+
|
271 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
272 |
+
self.norm2 = norm_layer(dim)
|
273 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
274 |
+
self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin")
|
275 |
+
|
276 |
+
def forward_part1(self, x, mask_matrix):
|
277 |
+
x_shape = x.size()
|
278 |
+
x = self.norm1(x)
|
279 |
+
if len(x_shape) == 5:
|
280 |
+
b, d, h, w, c = x.shape
|
281 |
+
window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
|
282 |
+
pad_l = pad_t = pad_d0 = 0
|
283 |
+
pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
|
284 |
+
pad_b = (window_size[1] - h % window_size[1]) % window_size[1]
|
285 |
+
pad_r = (window_size[2] - w % window_size[2]) % window_size[2]
|
286 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
|
287 |
+
_, dp, hp, wp, _ = x.shape
|
288 |
+
dims = [b, dp, hp, wp]
|
289 |
+
|
290 |
+
elif len(x_shape) == 4:
|
291 |
+
b, h, w, c = x.shape
|
292 |
+
window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
|
293 |
+
pad_l = pad_t = 0
|
294 |
+
pad_r = (window_size[0] - h % window_size[0]) % window_size[0]
|
295 |
+
pad_b = (window_size[1] - w % window_size[1]) % window_size[1]
|
296 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
297 |
+
_, hp, wp, _ = x.shape
|
298 |
+
dims = [b, hp, wp]
|
299 |
+
|
300 |
+
if any(i > 0 for i in shift_size):
|
301 |
+
if len(x_shape) == 5:
|
302 |
+
shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
|
303 |
+
elif len(x_shape) == 4:
|
304 |
+
shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
|
305 |
+
attn_mask = mask_matrix
|
306 |
+
else:
|
307 |
+
shifted_x = x
|
308 |
+
attn_mask = None
|
309 |
+
x_windows = window_partition(shifted_x, window_size)
|
310 |
+
attn_windows = self.attn(x_windows, mask=attn_mask)
|
311 |
+
attn_windows = attn_windows.view(-1, *(window_size + (c,)))
|
312 |
+
shifted_x = window_reverse(attn_windows, window_size, dims)
|
313 |
+
if any(i > 0 for i in shift_size):
|
314 |
+
if len(x_shape) == 5:
|
315 |
+
x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
|
316 |
+
elif len(x_shape) == 4:
|
317 |
+
x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
|
318 |
+
else:
|
319 |
+
x = shifted_x
|
320 |
+
|
321 |
+
if len(x_shape) == 5:
|
322 |
+
if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
|
323 |
+
x = x[:, :d, :h, :w, :].contiguous()
|
324 |
+
elif len(x_shape) == 4:
|
325 |
+
if pad_r > 0 or pad_b > 0:
|
326 |
+
x = x[:, :h, :w, :].contiguous()
|
327 |
+
|
328 |
+
return x
|
329 |
+
|
330 |
+
def forward_part2(self, x):
|
331 |
+
return self.drop_path(self.mlp(self.norm2(x)))
|
332 |
+
|
333 |
+
def load_from(self, weights, n_block, layer):
|
334 |
+
root = f"module.{layer}.0.blocks.{n_block}."
|
335 |
+
block_names = [
|
336 |
+
"norm1.weight",
|
337 |
+
"norm1.bias",
|
338 |
+
"attn.relative_position_bias_table",
|
339 |
+
"attn.relative_position_index",
|
340 |
+
"attn.qkv.weight",
|
341 |
+
"attn.qkv.bias",
|
342 |
+
"attn.proj.weight",
|
343 |
+
"attn.proj.bias",
|
344 |
+
"norm2.weight",
|
345 |
+
"norm2.bias",
|
346 |
+
"mlp.fc1.weight",
|
347 |
+
"mlp.fc1.bias",
|
348 |
+
"mlp.fc2.weight",
|
349 |
+
"mlp.fc2.bias",
|
350 |
+
]
|
351 |
+
with torch.no_grad():
|
352 |
+
self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
|
353 |
+
self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
|
354 |
+
self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
|
355 |
+
self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])
|
356 |
+
self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
|
357 |
+
self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
|
358 |
+
self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
|
359 |
+
self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]])
|
360 |
+
self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]])
|
361 |
+
self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]])
|
362 |
+
self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]])
|
363 |
+
self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]])
|
364 |
+
self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]])
|
365 |
+
self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]])
|
366 |
+
|
367 |
+
def forward(self, x, mask_matrix):
|
368 |
+
shortcut = x
|
369 |
+
if self.use_checkpoint:
|
370 |
+
x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
|
371 |
+
else:
|
372 |
+
x = self.forward_part1(x, mask_matrix)
|
373 |
+
x = shortcut + self.drop_path(x)
|
374 |
+
if self.use_checkpoint:
|
375 |
+
x = x + checkpoint.checkpoint(self.forward_part2, x)
|
376 |
+
else:
|
377 |
+
x = x + self.forward_part2(x)
|
378 |
+
return x
|
379 |
+
|
380 |
+
|
381 |
+
class PatchMerging(nn.Module):
|
382 |
+
"""
|
383 |
+
Patch merging layer based on: "Liu et al.,
|
384 |
+
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
385 |
+
<https://arxiv.org/abs/2103.14030>"
|
386 |
+
https://github.com/microsoft/Swin-Transformer
|
387 |
+
"""
|
388 |
+
|
389 |
+
def __init__(
|
390 |
+
self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3
|
391 |
+
) -> None: # type: ignore
|
392 |
+
"""
|
393 |
+
Args:
|
394 |
+
dim: number of feature channels.
|
395 |
+
norm_layer: normalization layer.
|
396 |
+
spatial_dims: number of spatial dims.
|
397 |
+
"""
|
398 |
+
|
399 |
+
super().__init__()
|
400 |
+
self.dim = dim
|
401 |
+
if spatial_dims == 3:
|
402 |
+
self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
|
403 |
+
self.norm = norm_layer(8 * dim)
|
404 |
+
elif spatial_dims == 2:
|
405 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
406 |
+
self.norm = norm_layer(4 * dim)
|
407 |
+
|
408 |
+
def forward(self, x):
|
409 |
+
|
410 |
+
x_shape = x.size()
|
411 |
+
if len(x_shape) == 5:
|
412 |
+
b, d, h, w, c = x_shape
|
413 |
+
pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
|
414 |
+
if pad_input:
|
415 |
+
x = F.pad(x, (0, 0, 0, d % 2, 0, w % 2, 0, h % 2))
|
416 |
+
x0 = x[:, 0::2, 0::2, 0::2, :]
|
417 |
+
x1 = x[:, 1::2, 0::2, 0::2, :]
|
418 |
+
x2 = x[:, 0::2, 1::2, 0::2, :]
|
419 |
+
x3 = x[:, 0::2, 0::2, 1::2, :]
|
420 |
+
x4 = x[:, 1::2, 0::2, 1::2, :]
|
421 |
+
x5 = x[:, 0::2, 1::2, 0::2, :]
|
422 |
+
x6 = x[:, 0::2, 0::2, 1::2, :]
|
423 |
+
x7 = x[:, 1::2, 1::2, 1::2, :]
|
424 |
+
x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
|
425 |
+
|
426 |
+
elif len(x_shape) == 4:
|
427 |
+
b, h, w, c = x_shape
|
428 |
+
pad_input = (h % 2 == 1) or (w % 2 == 1)
|
429 |
+
if pad_input:
|
430 |
+
x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))
|
431 |
+
x0 = x[:, 0::2, 0::2, :]
|
432 |
+
x1 = x[:, 1::2, 0::2, :]
|
433 |
+
x2 = x[:, 0::2, 1::2, :]
|
434 |
+
x3 = x[:, 1::2, 1::2, :]
|
435 |
+
x = torch.cat([x0, x1, x2, x3], -1)
|
436 |
+
|
437 |
+
x = self.norm(x)
|
438 |
+
x = self.reduction(x)
|
439 |
+
return x
|
440 |
+
|
441 |
+
|
442 |
+
def compute_mask(dims, window_size, shift_size, device):
|
443 |
+
"""Computing region masks based on: "Liu et al.,
|
444 |
+
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
445 |
+
<https://arxiv.org/abs/2103.14030>"
|
446 |
+
https://github.com/microsoft/Swin-Transformer
|
447 |
+
Args:
|
448 |
+
dims: dimension values.
|
449 |
+
window_size: local window size.
|
450 |
+
shift_size: shift size.
|
451 |
+
device: device.
|
452 |
+
"""
|
453 |
+
|
454 |
+
cnt = 0
|
455 |
+
|
456 |
+
if len(dims) == 3:
|
457 |
+
d, h, w = dims
|
458 |
+
img_mask = torch.zeros((1, d, h, w, 1), device=device)
|
459 |
+
for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
|
460 |
+
for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
|
461 |
+
for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
|
462 |
+
img_mask[:, d, h, w, :] = cnt
|
463 |
+
cnt += 1
|
464 |
+
|
465 |
+
elif len(dims) == 2:
|
466 |
+
h, w = dims
|
467 |
+
img_mask = torch.zeros((1, h, w, 1), device=device)
|
468 |
+
for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
|
469 |
+
for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
|
470 |
+
img_mask[:, h, w, :] = cnt
|
471 |
+
cnt += 1
|
472 |
+
|
473 |
+
mask_windows = window_partition(img_mask, window_size)
|
474 |
+
mask_windows = mask_windows.squeeze(-1)
|
475 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
476 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
477 |
+
|
478 |
+
return attn_mask
|
479 |
+
|
480 |
+
|
481 |
+
class BasicLayer(nn.Module):
|
482 |
+
"""
|
483 |
+
Basic Swin Transformer layer in one stage based on: "Liu et al.,
|
484 |
+
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
485 |
+
<https://arxiv.org/abs/2103.14030>"
|
486 |
+
https://github.com/microsoft/Swin-Transformer
|
487 |
+
"""
|
488 |
+
|
489 |
+
def __init__(
|
490 |
+
self,
|
491 |
+
dim: int,
|
492 |
+
depth: int,
|
493 |
+
num_heads: int,
|
494 |
+
window_size: Sequence[int],
|
495 |
+
drop_path: list,
|
496 |
+
mlp_ratio: float = 4.0,
|
497 |
+
qkv_bias: bool = False,
|
498 |
+
drop: float = 0.0,
|
499 |
+
attn_drop: float = 0.0,
|
500 |
+
norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
|
501 |
+
downsample: isinstance = None, # type: ignore
|
502 |
+
use_checkpoint: bool = False,
|
503 |
+
) -> None:
|
504 |
+
"""
|
505 |
+
Args:
|
506 |
+
dim: number of feature channels.
|
507 |
+
depths: number of layers in each stage.
|
508 |
+
num_heads: number of attention heads.
|
509 |
+
window_size: local window size.
|
510 |
+
drop_path: stochastic depth rate.
|
511 |
+
mlp_ratio: ratio of mlp hidden dim to embedding dim.
|
512 |
+
qkv_bias: add a learnable bias to query, key, value.
|
513 |
+
drop: dropout rate.
|
514 |
+
attn_drop: attention dropout rate.
|
515 |
+
norm_layer: normalization layer.
|
516 |
+
downsample: downsample layer at the end of the layer.
|
517 |
+
use_checkpoint: use gradient checkpointing for reduced memory usage.
|
518 |
+
"""
|
519 |
+
|
520 |
+
super().__init__()
|
521 |
+
self.window_size = window_size
|
522 |
+
self.shift_size = tuple(i // 2 for i in window_size)
|
523 |
+
self.no_shift = tuple(0 for i in window_size)
|
524 |
+
self.depth = depth
|
525 |
+
self.use_checkpoint = use_checkpoint
|
526 |
+
self.blocks = nn.ModuleList(
|
527 |
+
[
|
528 |
+
SwinTransformerBlock(
|
529 |
+
dim=dim,
|
530 |
+
num_heads=num_heads,
|
531 |
+
window_size=self.window_size,
|
532 |
+
shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,
|
533 |
+
mlp_ratio=mlp_ratio,
|
534 |
+
qkv_bias=qkv_bias,
|
535 |
+
drop=drop,
|
536 |
+
attn_drop=attn_drop,
|
537 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
538 |
+
norm_layer=norm_layer,
|
539 |
+
use_checkpoint=use_checkpoint,
|
540 |
+
)
|
541 |
+
for i in range(depth)
|
542 |
+
]
|
543 |
+
)
|
544 |
+
self.downsample = downsample
|
545 |
+
if self.downsample is not None:
|
546 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))
|
547 |
+
|
548 |
+
def forward(self, x):
|
549 |
+
x_shape = x.size()
|
550 |
+
if len(x_shape) == 5:
|
551 |
+
b, c, d, h, w = x_shape
|
552 |
+
window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
|
553 |
+
x = rearrange(x, "b c d h w -> b d h w c")
|
554 |
+
dp = int(np.ceil(d / window_size[0])) * window_size[0]
|
555 |
+
hp = int(np.ceil(h / window_size[1])) * window_size[1]
|
556 |
+
wp = int(np.ceil(w / window_size[2])) * window_size[2]
|
557 |
+
attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)
|
558 |
+
for blk in self.blocks:
|
559 |
+
x = blk(x, attn_mask)
|
560 |
+
x = x.view(b, d, h, w, -1)
|
561 |
+
if self.downsample is not None:
|
562 |
+
x = self.downsample(x)
|
563 |
+
x = rearrange(x, "b d h w c -> b c d h w")
|
564 |
+
|
565 |
+
elif len(x_shape) == 4:
|
566 |
+
b, c, h, w = x_shape
|
567 |
+
window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
|
568 |
+
x = rearrange(x, "b c h w -> b h w c")
|
569 |
+
hp = int(np.ceil(h / window_size[0])) * window_size[0]
|
570 |
+
wp = int(np.ceil(w / window_size[1])) * window_size[1]
|
571 |
+
attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
|
572 |
+
for blk in self.blocks:
|
573 |
+
x = blk(x, attn_mask)
|
574 |
+
x = x.view(b, h, w, -1)
|
575 |
+
if self.downsample is not None:
|
576 |
+
x = self.downsample(x)
|
577 |
+
x = rearrange(x, "b h w c -> b c h w")
|
578 |
+
return x
|
579 |
+
|
580 |
+
|
581 |
+
class SwinTransformer(nn.Module):
|
582 |
+
"""
|
583 |
+
Swin Transformer based on: "Liu et al.,
|
584 |
+
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
|
585 |
+
<https://arxiv.org/abs/2103.14030>"
|
586 |
+
https://github.com/microsoft/Swin-Transformer
|
587 |
+
"""
|
588 |
+
|
589 |
+
def __init__(
|
590 |
+
self,
|
591 |
+
in_chans: int,
|
592 |
+
embed_dim: int,
|
593 |
+
window_size: Sequence[int],
|
594 |
+
patch_size: Sequence[int],
|
595 |
+
depths: Sequence[int],
|
596 |
+
num_heads: Sequence[int],
|
597 |
+
mlp_ratio: float = 4.0,
|
598 |
+
qkv_bias: bool = True,
|
599 |
+
drop_rate: float = 0.0,
|
600 |
+
attn_drop_rate: float = 0.0,
|
601 |
+
drop_path_rate: float = 0.0,
|
602 |
+
norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
|
603 |
+
patch_norm: bool = False,
|
604 |
+
use_checkpoint: bool = False,
|
605 |
+
spatial_dims: int = 3,
|
606 |
+
) -> None:
|
607 |
+
"""
|
608 |
+
Args:
|
609 |
+
in_chans: dimension of input channels.
|
610 |
+
embed_dim: number of linear projection output channels.
|
611 |
+
window_size: local window size.
|
612 |
+
patch_size: patch size.
|
613 |
+
depths: number of layers in each stage.
|
614 |
+
num_heads: number of attention heads.
|
615 |
+
mlp_ratio: ratio of mlp hidden dim to embedding dim.
|
616 |
+
qkv_bias: add a learnable bias to query, key, value.
|
617 |
+
drop_rate: dropout rate.
|
618 |
+
attn_drop_rate: attention dropout rate.
|
619 |
+
drop_path_rate: stochastic depth rate.
|
620 |
+
norm_layer: normalization layer.
|
621 |
+
patch_norm: add normalization after patch embedding.
|
622 |
+
use_checkpoint: use gradient checkpointing for reduced memory usage.
|
623 |
+
spatial_dims: spatial dimension.
|
624 |
+
"""
|
625 |
+
|
626 |
+
super().__init__()
|
627 |
+
self.num_layers = len(depths)
|
628 |
+
self.embed_dim = embed_dim
|
629 |
+
self.patch_norm = patch_norm
|
630 |
+
self.window_size = window_size
|
631 |
+
self.patch_size = patch_size
|
632 |
+
self.patch_embed = PatchEmbed(
|
633 |
+
patch_size=self.patch_size,
|
634 |
+
in_chans=in_chans,
|
635 |
+
embed_dim=embed_dim,
|
636 |
+
norm_layer=norm_layer if self.patch_norm else None, # type: ignore
|
637 |
+
spatial_dims=spatial_dims,
|
638 |
+
)
|
639 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
640 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
641 |
+
# self.layers1 = nn.ModuleList()
|
642 |
+
# self.layers2 = nn.ModuleList()
|
643 |
+
# self.layers3 = nn.ModuleList()
|
644 |
+
# self.layers4 = nn.ModuleList()
|
645 |
+
self.layers = nn.ModuleList()
|
646 |
+
for i_layer in range(self.num_layers):
|
647 |
+
layer = BasicLayer(
|
648 |
+
dim=int(embed_dim * 2**i_layer),
|
649 |
+
depth=depths[i_layer],
|
650 |
+
num_heads=num_heads[i_layer],
|
651 |
+
window_size=self.window_size,
|
652 |
+
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
|
653 |
+
mlp_ratio=mlp_ratio,
|
654 |
+
qkv_bias=qkv_bias,
|
655 |
+
drop=drop_rate,
|
656 |
+
attn_drop=attn_drop_rate,
|
657 |
+
norm_layer=norm_layer,
|
658 |
+
downsample=PatchMerging,
|
659 |
+
use_checkpoint=use_checkpoint,
|
660 |
+
)
|
661 |
+
self.layers.append(layer)
|
662 |
+
# if i_layer == 0:
|
663 |
+
# self.layers1.append(layer)
|
664 |
+
# elif i_layer == 1:
|
665 |
+
# self.layers2.append(layer)
|
666 |
+
# elif i_layer == 2:
|
667 |
+
# self.layers3.append(layer)
|
668 |
+
# elif i_layer == 3:
|
669 |
+
# self.layers4.append(layer)
|
670 |
+
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
|
671 |
+
|
672 |
+
def proj_out(self, x, normalize=False):
|
673 |
+
if normalize:
|
674 |
+
x_shape = x.size()
|
675 |
+
if len(x_shape) == 5:
|
676 |
+
n, ch, d, h, w = x_shape
|
677 |
+
x = rearrange(x, "n c d h w -> n d h w c")
|
678 |
+
x = F.layer_norm(x, [ch])
|
679 |
+
x = rearrange(x, "n d h w c -> n c d h w")
|
680 |
+
elif len(x_shape) == 4:
|
681 |
+
n, ch, h, w = x_shape
|
682 |
+
x = rearrange(x, "n c h w -> n h w c")
|
683 |
+
x = F.layer_norm(x, [ch])
|
684 |
+
x = rearrange(x, "n h w c -> n c h w")
|
685 |
+
return x
|
686 |
+
|
687 |
+
def forward(self, x, normalize=True):
|
688 |
+
# x input: [B*sample, C(1), H, W, D]
|
689 |
+
# x = rearrange(x, "b c h w d -> b c d h w")
|
690 |
+
# print('>> input: ', x.shape)
|
691 |
+
x = self.patch_embed(x)
|
692 |
+
# print('>> patch_embed: ', x.shape)
|
693 |
+
x = self.pos_drop(x)
|
694 |
+
for layer in self.layers:
|
695 |
+
x = layer(x.contiguous())
|
696 |
+
# print('>> layer: ', x.shape)
|
697 |
+
return x
|
698 |
+
# # x0_out = self.proj_out(x0, normalize)
|
699 |
+
# x1 = self.layers1[0](x0.contiguous())
|
700 |
+
# # x1_out = self.proj_out(x1, normalize)
|
701 |
+
# x2 = self.layers2[0](x1.contiguous())
|
702 |
+
# # x2_out = self.proj_out(x2, normalize)
|
703 |
+
# x3 = self.layers3[0](x2.contiguous())
|
704 |
+
# # x3_out = self.proj_out(x3, normalize)
|
705 |
+
# x4 = self.layers4[0](x3.contiguous())
|
706 |
+
# # x4_out = self.proj_out(x4, normalize)
|
707 |
+
# # return [x0_out, x1_out, x2_out, x3_out, x4_out]
|
708 |
+
|
709 |
+
|
model/segment_anything_volumetric/modeling/.ipynb_checkpoints/prompt_encoder-checkpoint.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
from typing import Any, Optional, Tuple, Type
|
12 |
+
|
13 |
+
from .common import LayerNorm2d
|
14 |
+
import os
|
15 |
+
|
16 |
+
class PromptEncoder(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
embed_dim: int,
|
20 |
+
image_embedding_size: Tuple[int, int, int],
|
21 |
+
input_image_size: Tuple[int, int, int],
|
22 |
+
mask_in_chans: int,
|
23 |
+
activation: Type[nn.Module] = nn.GELU,
|
24 |
+
) -> None:
|
25 |
+
"""
|
26 |
+
Encodes prompts for input to SAM's mask decoder.
|
27 |
+
|
28 |
+
Arguments:
|
29 |
+
embed_dim (int): The prompts' embedding dimension
|
30 |
+
image_embedding_size (tuple(int, int)): The spatial size of the
|
31 |
+
image embedding, as (H, W).
|
32 |
+
input_image_size (int): The padded size of the image as input
|
33 |
+
to the image encoder, as (H, W).
|
34 |
+
mask_in_chans (int): The number of hidden channels used for
|
35 |
+
encoding input masks.
|
36 |
+
activation (nn.Module): The activation to use when encoding
|
37 |
+
input masks.
|
38 |
+
"""
|
39 |
+
super().__init__()
|
40 |
+
self.embed_dim = embed_dim
|
41 |
+
self.input_image_size = input_image_size
|
42 |
+
self.image_embedding_size = image_embedding_size
|
43 |
+
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
44 |
+
|
45 |
+
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
|
46 |
+
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
|
47 |
+
self.point_embeddings = nn.ModuleList(point_embeddings)
|
48 |
+
self.not_a_point_embed = nn.Embedding(1, embed_dim)
|
49 |
+
|
50 |
+
self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1], 4 * image_embedding_size[2])
|
51 |
+
self.mask_downscaling = nn.Sequential(
|
52 |
+
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
|
53 |
+
LayerNorm2d(mask_in_chans // 4),
|
54 |
+
activation(),
|
55 |
+
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
|
56 |
+
LayerNorm2d(mask_in_chans),
|
57 |
+
activation(),
|
58 |
+
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
|
59 |
+
)
|
60 |
+
self.no_mask_embed = nn.Embedding(1, embed_dim)
|
61 |
+
|
62 |
+
def get_dense_pe(self) -> torch.Tensor:
|
63 |
+
"""
|
64 |
+
Returns the positional encoding used to encode point prompts,
|
65 |
+
applied to a dense set of points the shape of the image encoding.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
torch.Tensor: Positional encoding with shape
|
69 |
+
1x(embed_dim)x(embedding_h)x(embedding_w)
|
70 |
+
"""
|
71 |
+
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
72 |
+
|
73 |
+
def _embed_points(
|
74 |
+
self,
|
75 |
+
points: torch.Tensor,
|
76 |
+
labels: torch.Tensor,
|
77 |
+
pad: bool,
|
78 |
+
) -> torch.Tensor:
|
79 |
+
"""Embeds point prompts."""
|
80 |
+
points = points + 0.5 # Shift to center of pixel
|
81 |
+
if pad:
|
82 |
+
padding_point = torch.zeros((points.shape[0], 1, 3), device=points.device)
|
83 |
+
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
|
84 |
+
points = torch.cat([points, padding_point], dim=1)
|
85 |
+
labels = torch.cat([labels, padding_label], dim=1)
|
86 |
+
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
|
87 |
+
point_embedding[labels == -1] = 0.0
|
88 |
+
point_embedding[labels == -1] += self.not_a_point_embed.weight
|
89 |
+
point_embedding[labels == 0] += self.point_embeddings[0].weight
|
90 |
+
point_embedding[labels == 1] += self.point_embeddings[1].weight
|
91 |
+
return point_embedding
|
92 |
+
|
93 |
+
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
94 |
+
"""Embeds box prompts."""
|
95 |
+
boxes = boxes + 0.5 # Shift to center of pixel
|
96 |
+
coords = boxes.reshape(-1, 2, 3)
|
97 |
+
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
|
98 |
+
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
|
99 |
+
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
|
100 |
+
return corner_embedding
|
101 |
+
|
102 |
+
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
|
103 |
+
"""Embeds mask inputs."""
|
104 |
+
mask_embedding = self.mask_downscaling(masks)
|
105 |
+
return mask_embedding
|
106 |
+
|
107 |
+
def _get_batch_size(
|
108 |
+
self,
|
109 |
+
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
110 |
+
boxes: Optional[torch.Tensor],
|
111 |
+
masks: Optional[torch.Tensor],
|
112 |
+
text_embedding: Optional[torch.Tensor],
|
113 |
+
) -> int:
|
114 |
+
"""
|
115 |
+
Gets the batch size of the output given the batch size of the input prompts.
|
116 |
+
"""
|
117 |
+
if points is not None:
|
118 |
+
return points[0].shape[0]
|
119 |
+
elif boxes is not None:
|
120 |
+
return boxes.shape[0]
|
121 |
+
elif masks is not None:
|
122 |
+
return masks.shape[0]
|
123 |
+
elif text_embedding is not None:
|
124 |
+
return text_embedding.shape[0]
|
125 |
+
else:
|
126 |
+
return 1
|
127 |
+
|
128 |
+
def _get_device(self) -> torch.device:
|
129 |
+
return self.point_embeddings[0].weight.device
|
130 |
+
|
131 |
+
def forward(
|
132 |
+
self,
|
133 |
+
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
134 |
+
boxes: Optional[torch.Tensor],
|
135 |
+
masks: Optional[torch.Tensor],
|
136 |
+
text_embedding: Optional[torch.Tensor],
|
137 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
138 |
+
"""
|
139 |
+
Embeds different types of prompts, returning both sparse and dense
|
140 |
+
embeddings.
|
141 |
+
|
142 |
+
Arguments:
|
143 |
+
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
|
144 |
+
and labels to embed.
|
145 |
+
boxes (torch.Tensor or none): boxes to embed
|
146 |
+
masks (torch.Tensor or none): masks to embed
|
147 |
+
text: test prompt (B, 768)
|
148 |
+
|
149 |
+
Returns:
|
150 |
+
torch.Tensor: sparse embeddings for the points and boxes, with shape
|
151 |
+
BxNx(embed_dim), where N is determined by the number of input points
|
152 |
+
and boxes.
|
153 |
+
torch.Tensor: dense embeddings for the masks, in the shape
|
154 |
+
Bx(embed_dim)x(embed_H)x(embed_W)
|
155 |
+
"""
|
156 |
+
# print('prompt encoder here...')
|
157 |
+
|
158 |
+
bs = self._get_batch_size(points, boxes, masks, text_embedding)
|
159 |
+
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
|
160 |
+
# print('sparse_embeddings ', sparse_embeddings.shape)
|
161 |
+
if points is not None:
|
162 |
+
coords, labels = points
|
163 |
+
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
|
164 |
+
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
|
165 |
+
|
166 |
+
if boxes is not None:
|
167 |
+
box_embeddings = self._embed_boxes(boxes)
|
168 |
+
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
|
169 |
+
|
170 |
+
if text_embedding is not None:
|
171 |
+
sparse_embeddings = torch.cat([sparse_embeddings, text_embedding.unsqueeze(dim=1)], dim=1)
|
172 |
+
|
173 |
+
# print('box_embeddings ', box_embeddings.shape)
|
174 |
+
# print('sparse_embeddings after box/point/text', sparse_embeddings.shape)
|
175 |
+
|
176 |
+
if masks is not None:
|
177 |
+
dense_embeddings = self._embed_masks(masks)
|
178 |
+
else:
|
179 |
+
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1, 1).expand(
|
180 |
+
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1], self.image_embedding_size[2]
|
181 |
+
)
|
182 |
+
# print('dense_embeddings ', dense_embeddings.shape)
|
183 |
+
return sparse_embeddings, dense_embeddings
|
184 |
+
|
185 |
+
|
186 |
+
class PositionEmbeddingRandom(nn.Module):
|
187 |
+
"""
|
188 |
+
Positional encoding using random spatial frequencies.
|
189 |
+
"""
|
190 |
+
|
191 |
+
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
|
192 |
+
super().__init__()
|
193 |
+
if scale is None or scale <= 0.0:
|
194 |
+
scale = 1.0
|
195 |
+
self.register_buffer(
|
196 |
+
"positional_encoding_gaussian_matrix",
|
197 |
+
scale * torch.randn((3, num_pos_feats)),
|
198 |
+
)
|
199 |
+
|
200 |
+
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
201 |
+
"""Positionally encode points that are normalized to [0,1]."""
|
202 |
+
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
203 |
+
coords = 2 * coords - 1
|
204 |
+
coords = coords @ self.positional_encoding_gaussian_matrix
|
205 |
+
coords = 2 * np.pi * coords
|
206 |
+
# outputs d_1 x ... x d_n x C shape
|
207 |
+
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
208 |
+
|
209 |
+
def forward(self, size: Tuple[int, int, int]) -> torch.Tensor:
|
210 |
+
"""Generate positional encoding for a grid of the specified size."""
|
211 |
+
h, w, d = size
|
212 |
+
device: Any = self.positional_encoding_gaussian_matrix.device
|
213 |
+
grid = torch.ones((h, w, d), device=device, dtype=torch.float32)
|
214 |
+
y_embed = grid.cumsum(dim=0) - 0.5
|
215 |
+
x_embed = grid.cumsum(dim=1) - 0.5
|
216 |
+
z_embed = grid.cumsum(dim=2) - 0.5
|
217 |
+
y_embed = y_embed / h
|
218 |
+
x_embed = x_embed / w
|
219 |
+
z_embed = z_embed / d
|
220 |
+
|
221 |
+
pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1))
|
222 |
+
return pe.permute(3, 0, 1, 2) # C x H x W x D
|
223 |
+
|
224 |
+
def forward_with_coords(
|
225 |
+
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
|
226 |
+
) -> torch.Tensor:
|
227 |
+
"""Positionally encode points that are not normalized to [0,1]."""
|
228 |
+
coords = coords_input.clone()
|
229 |
+
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
230 |
+
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
|
231 |
+
coords[:, :, 2] = coords[:, :, 2] / image_size[2]
|
232 |
+
return self._pe_encoding(coords.to(torch.float)) # B x N x C
|
model/segment_anything_volumetric/modeling/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from .sam import Sam
|
8 |
+
from .image_encoder import ImageEncoderViT
|
9 |
+
from .mask_decoder import MaskDecoder
|
10 |
+
from .prompt_encoder import PromptEncoder
|
11 |
+
from .transformer import TwoWayTransformer
|
model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (394 Bytes). View file
|
|
model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (424 Bytes). View file
|
|
model/segment_anything_volumetric/modeling/__pycache__/common.cpython-310.pyc
ADDED
Binary file (1.75 kB). View file
|
|
model/segment_anything_volumetric/modeling/__pycache__/common.cpython-39.pyc
ADDED
Binary file (1.77 kB). View file
|
|
model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-310.pyc
ADDED
Binary file (12.6 kB). View file
|
|
model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-39.pyc
ADDED
Binary file (11.4 kB). View file
|
|
model/segment_anything_volumetric/modeling/__pycache__/image_encoder_swin.cpython-39.pyc
ADDED
Binary file (21.5 kB). View file
|
|
model/segment_anything_volumetric/modeling/__pycache__/mask_decoder.cpython-310.pyc
ADDED
Binary file (5.5 kB). View file
|
|
model/segment_anything_volumetric/modeling/__pycache__/mask_decoder.cpython-39.pyc
ADDED
Binary file (6.09 kB). View file
|
|
model/segment_anything_volumetric/modeling/__pycache__/prompt_encoder.cpython-310.pyc
ADDED
Binary file (7.68 kB). View file
|
|
model/segment_anything_volumetric/modeling/__pycache__/prompt_encoder.cpython-39.pyc
ADDED
Binary file (8.01 kB). View file
|
|
model/segment_anything_volumetric/modeling/__pycache__/sam.cpython-310.pyc
ADDED
Binary file (6.66 kB). View file
|
|
model/segment_anything_volumetric/modeling/__pycache__/sam.cpython-39.pyc
ADDED
Binary file (6.67 kB). View file
|
|
model/segment_anything_volumetric/modeling/__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (6.6 kB). View file
|
|