DiGuaQiu commited on
Commit
c8ffaae
1 Parent(s): 4ba5c46

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +226 -0
app.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code copied and modified from: https://huggingface.co/spaces/BAAI/SegVol
2
+
3
+ import tempfile
4
+ from pathlib import Path
5
+
6
+ import nibabel as nib
7
+ import numpy as np
8
+ from PIL import ImageDraw
9
+ from streamlit_drawable_canvas import st_canvas
10
+ from streamlit_image_coordinates import streamlit_image_coordinates
11
+ import nibabel as nib
12
+ import SimpleITK as sitk
13
+ import streamlit as st
14
+ import utils
15
+ from utils import (
16
+ initial_rectangle,
17
+ make_fig,
18
+ reflect_box_into_model,
19
+ reflect_json_data_to_3D_box,
20
+ run,
21
+ )
22
+
23
+ # from viewer import BasicViewer
24
+
25
+ print("script run")
26
+ st.title("MRSegmentator")
27
+
28
+ #############################################
29
+ # init session_state
30
+ if "option" not in st.session_state:
31
+ st.session_state.option = None
32
+
33
+ if "reset_demo_case" not in st.session_state:
34
+ st.session_state.reset_demo_case = False
35
+
36
+ if "preds_3D" not in st.session_state:
37
+ st.session_state.preds_3D = None
38
+ st.session_state.preds_path = None
39
+
40
+ if "data_item" not in st.session_state:
41
+ st.session_state.data_item = None
42
+
43
+ if "rectangle_3Dbox" not in st.session_state:
44
+ st.session_state.rectangle_3Dbox = [0, 0, 0, 0, 0, 0]
45
+
46
+ if "running" not in st.session_state:
47
+ st.session_state.running = False
48
+
49
+ if "transparency" not in st.session_state:
50
+ st.session_state.transparency = 0.25
51
+
52
+ case_list = [
53
+ "images/amos_0541_MRI.nii.gz",
54
+ "images/amos_0571_MRI.nii.gz",
55
+ "images/amos_0001_CT.nii.gz",
56
+ ]
57
+
58
+ #############################################
59
+
60
+
61
+ #############################################
62
+ # reset functions
63
+ def clear_prompts():
64
+ st.session_state.rectangle_3Dbox = [0, 0, 0, 0, 0, 0]
65
+
66
+
67
+ def reset_demo_case():
68
+ st.session_state.data_item = None
69
+ st.session_state.reset_demo_case = True
70
+ clear_prompts()
71
+
72
+ def clear_file():
73
+ st.session_state.option = None
74
+ reset_demo_case()
75
+ clear_prompts()
76
+
77
+
78
+ #############################################
79
+
80
+ github_col, arxive_col = st.columns(2)
81
+
82
+ with github_col:
83
+ st.write("Git: https://github.com/hhaentze/mrsegmentator")
84
+
85
+ with arxive_col:
86
+ st.write("Paper: https://arxiv.org/abs/2405.06463")
87
+
88
+ # modify demo case here
89
+ demo_type = st.radio("Demo case source", ["Select", "Upload"], on_change=clear_file)
90
+
91
+ with tempfile.TemporaryDirectory() as tmpdirname:
92
+
93
+ # modify demo case here
94
+ if demo_type == "Select":
95
+ uploaded_file = st.selectbox(
96
+ "Select a demo case",
97
+ case_list,
98
+ index=None,
99
+ placeholder="Select a demo case...",
100
+ on_change=reset_demo_case,
101
+ )
102
+ else:
103
+ uploaded_file = st.file_uploader(
104
+ "Upload demo case(nii.gz)", type="nii.gz", on_change=reset_demo_case
105
+ )
106
+
107
+ if( uploaded_file is not None ):
108
+ with open(tmpdirname + "/" + uploaded_file.name, 'wb') as f:
109
+ f.write(uploaded_file.getvalue())
110
+ uploaded_file = tmpdirname + "/" + uploaded_file.name
111
+
112
+ st.session_state.option = uploaded_file
113
+
114
+ if (
115
+ st.session_state.option is not None
116
+ and st.session_state.reset_demo_case
117
+ or (st.session_state.data_item is None and st.session_state.option is not None)
118
+ ):
119
+
120
+ st.session_state.data_item = utils.read_image(Path(__file__).parent / str(uploaded_file))
121
+ st.session_state.data_item_ori = sitk.ReadImage(Path(__file__).parent / str(uploaded_file))
122
+ st.session_state.reset_demo_case = False
123
+ st.session_state.preds_3D = None
124
+ st.session_state.preds_path = None
125
+
126
+
127
+ if st.session_state.option is None:
128
+ st.write("please select demo case first")
129
+ else:
130
+ image_3D = st.session_state.data_item
131
+ px_range = st.slider( "Select intensity range",
132
+ int(image_3D.min()),
133
+ int(image_3D.max()),
134
+ (int(image_3D.min()), int(image_3D.max()))
135
+ )
136
+ col_control1, col_control2 = st.columns(2)
137
+
138
+ with col_control1:
139
+ selected_index_z = st.slider(
140
+ "Axial view", 0, image_3D.shape[0] - 1, image_3D.shape[0] // 2, key="xy", disabled=st.session_state.running
141
+ )
142
+
143
+ with col_control2:
144
+ selected_index_y = st.slider(
145
+ "Coronal view", 0, image_3D.shape[1] - 1, image_3D.shape[1] // 2, key="xz", disabled=st.session_state.running
146
+ )
147
+
148
+ col_image1, col_image2 = st.columns(2)
149
+
150
+ if st.session_state.preds_3D is not None:
151
+ st.session_state.transparency = st.slider(
152
+ "Mask opacity", 0.0, 1.0, 0.5, disabled=st.session_state.running
153
+ )
154
+
155
+ with col_image1:
156
+
157
+ image_z_array = image_3D[selected_index_z]
158
+
159
+ preds_z_array = None
160
+ if st.session_state.preds_3D is not None:
161
+ preds_z_array = st.session_state.preds_3D[selected_index_z]
162
+
163
+ image_z = make_fig(image_z_array, preds_z_array, px_range, st.session_state.transparency)
164
+ st.image(image_z, use_column_width=False)
165
+
166
+ with col_image2:
167
+ image_y_array = image_3D[:, selected_index_y, :]
168
+
169
+ preds_y_array = None
170
+ if st.session_state.preds_3D is not None:
171
+ preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
172
+
173
+ image_y = make_fig(image_y_array, preds_y_array, px_range, st.session_state.transparency)
174
+ st.image(image_y, use_column_width=False)
175
+
176
+ ######################################################
177
+
178
+ col1, col2, col3 = st.columns(3)
179
+
180
+ with col1:
181
+ if st.button(
182
+ "Clear",
183
+ use_container_width=True,
184
+ disabled=(st.session_state.option is None or (st.session_state.preds_3D is None)),
185
+ ):
186
+ clear_prompts()
187
+ st.session_state.preds_3D = None
188
+ st.session_state.preds_path = None
189
+ st.rerun()
190
+
191
+ with col2:
192
+
193
+ if st.session_state.preds_3D is not None and st.session_state.data_item is not None:
194
+
195
+ with tempfile.NamedTemporaryFile(suffix=".nii.gz") as tmpfile:
196
+
197
+ preds = st.session_state.preds_3D_ori
198
+ #result_image.CopyInformation(inputImage)
199
+ sitk.WriteImage(preds, tmpfile.name)
200
+ #nib.save(st.session_state.preds_3D, tmpfile.name)
201
+ with open(tmpfile.name, "rb") as f:
202
+ bytes_data = f.read()
203
+ st.download_button(
204
+ label="Download result(.nii.gz)",
205
+ data=bytes_data,
206
+ file_name="segmentation.nii.gz",
207
+ mime="application/octet-stream",
208
+ disabled=False,
209
+ )
210
+
211
+ with col3:
212
+ run_button_name = "Run" if not st.session_state.running else "Running"
213
+ if st.button(
214
+ run_button_name,
215
+ type="primary",
216
+ use_container_width=True,
217
+ disabled=(st.session_state.data_item is None or st.session_state.running),
218
+ ):
219
+ st.session_state.running = True
220
+ st.rerun()
221
+
222
+ if st.session_state.running:
223
+ st.session_state.running = False
224
+ with st.status("Running...", expanded=False) as status:
225
+ run(tmpdirname)
226
+ st.rerun()