FelixzeroSun commited on
Commit
035de9f
·
1 Parent(s): 23ed1a5
Files changed (2) hide show
  1. app_0.py +0 -115
  2. app_1.py +0 -204
app_0.py DELETED
@@ -1,115 +0,0 @@
1
- # app.py
2
- import os
3
- from huggingface_hub import snapshot_download
4
-
5
- HF_REPO = "Synthard2025KoalAI/synthrad2025_task1" # 你的模型仓库
6
- LOCAL_WEIGHTS_DIR = os.path.abspath("weights/task1") # 下载到 Space 工作目录
7
- repo_dir = snapshot_download(
8
- HF_REPO,
9
- repo_type="model",
10
- local_dir=LOCAL_WEIGHTS_DIR,
11
- local_dir_use_symlinks=False,
12
- token=os.getenv("HF_TOKEN"), # 私有模型需要
13
- )
14
-
15
-
16
- os.environ["nnUNet_results"] = repo_dir
17
- os.environ["nnUNet_raw"] = "./nnunet_raw"
18
- os.environ["nnUNet_preprocessed"] = "./nnunet_preprocessed"
19
- # os.environ["nnUNet_results"] = "./nnunet_results"
20
- os.environ["OPENBLAS_NUM_THREADS"] = "1"
21
-
22
- import streamlit as st
23
- import numpy as np
24
- import SimpleITK as sitk
25
- import io
26
-
27
- from process import SynthradAlgorithm
28
-
29
- st.set_page_config(page_title="SynthRad (nnUNetv2) Demo", layout="wide")
30
- st.title("SynthRad (MRI+Mask → synthetic CT) — Streamlit Demo")
31
-
32
- # ---- 全局算法(避免重复加载模型)----
33
- if "algo" not in st.session_state:
34
- st.session_state.algo = SynthradAlgorithm()
35
- if "synth_ct" not in st.session_state:
36
- st.session_state.synth_ct = None # SimpleITK.Image
37
- if "orig_meta" not in st.session_state:
38
- st.session_state.orig_meta = None # (spacing, origin, direction)
39
- if "vol_np" not in st.session_state:
40
- st.session_state.vol_np = None # numpy (D,H,W)
41
-
42
- # ---- 左上:输入区(保留“框架”风格)----
43
- c1, c2, c3 = st.columns([2, 2, 1])
44
- with c1:
45
- mri_file = st.file_uploader("MRI volume (.nii/.nii.gz/.mha)", type=["nii", "nii.gz", "mha"], key="mri")
46
- with c2:
47
- mask_file = st.file_uploader("Mask volume (.nii/.nii.gz/.mha)", type=["nii", "nii.gz", "mha"], key="mask")
48
- with c3:
49
- region = st.radio("Region", ["Head and Neck", "Abdomen", "Thorax"], index=1)
50
-
51
- run_btn = st.button("Run", type="primary", disabled=not (mri_file and mask_file))
52
-
53
- def _read_sitk_from_uploaded(f):
54
- # 把上传文件读到 SimpleITK
55
- suffix = ".nii.gz" if f.name.endswith(".nii.gz") else os.path.splitext(f.name)[1]
56
- bio = io.BytesIO(f.read())
57
- # SimpleITK 不能直接读 Bytes,需要写临时文件
58
- import tempfile
59
- with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
60
- tmp.write(bio.getvalue())
61
- tmp.flush()
62
- path = tmp.name
63
- img = sitk.ReadImage(path)
64
- try:
65
- os.remove(path)
66
- except Exception:
67
- pass
68
- return img
69
-
70
- def _norm2u8(slice2d):
71
- s = slice2d.astype(np.float32)
72
- s = (s - np.percentile(s, 1)) / (np.percentile(s, 99) - np.percentile(s, 1) + 1e-6)
73
- s = np.clip(s, 0, 1)
74
- return (s * 255).astype(np.uint8)
75
-
76
- if run_btn:
77
- if not (mri_file and mask_file):
78
- st.warning("请同时上传 MRI 和 Mask")
79
- else:
80
- with st.spinner("Running nnUNetv2 SynthradAlgorithm..."):
81
- mr_img = _read_sitk_from_uploaded(mri_file)
82
- mask_img = _read_sitk_from_uploaded(mask_file)
83
-
84
- st.session_state.orig_meta = (mr_img.GetSpacing(), mr_img.GetOrigin(), mr_img.GetDirection())
85
- out_img = st.session_state.algo.predict({"image": mr_img, "mask": mask_img, "region": region})
86
- st.session_state.synth_ct = out_img
87
- st.session_state.vol_np = sitk.GetArrayFromImage(out_img).astype(np.float32) # (D,H,W)
88
-
89
- if st.session_state.vol_np is None:
90
- st.info("请上传 MRI + Mask 并点击 Run")
91
- else:
92
- vol = st.session_state.vol_np
93
- D, H, W = vol.shape
94
-
95
- colL, colR = st.columns(2)
96
- with colL:
97
- z_idx = st.slider("Axial (Z)", 0, D - 1, D // 2, key="z_idx")
98
- axial = _norm2u8(vol[z_idx, :, :])
99
- st.image(axial, caption=f"Axial slice z={z_idx}", use_column_width=True)
100
-
101
- with colR:
102
- y_idx = st.slider("Coronal (Y)", 0, H - 1, H // 2, key="y_idx")
103
- coronal = _norm2u8(vol[:, y_idx, :])
104
- st.image(coronal, caption=f"Coronal slice y={y_idx}", use_column_width=True)
105
-
106
- spacing, origin, direction = st.session_state.orig_meta
107
- out_path = "synth_ct.nii.gz"
108
- sitk.WriteImage(st.session_state.synth_ct, out_path)
109
- with open(out_path, "rb") as f:
110
- st.download_button(
111
- label="Download synthetic CT (.nii.gz)",
112
- data=f.read(),
113
- file_name="synth_ct.nii.gz",
114
- mime="application/octet-stream",
115
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_1.py DELETED
@@ -1,204 +0,0 @@
1
- # app.py
2
- import os
3
- from huggingface_hub import snapshot_download
4
-
5
- from PIL import Image, ImageDraw
6
- HF_REPO = "aehrc/Synthrad2025"
7
- LOCAL_WEIGHTS_DIR = os.path.abspath("weights/task1")
8
-
9
- token = os.getenv("HF_TOKEN")
10
- if token is None:
11
-
12
- print("[Warn] HF_TOKEN not set. If the model repo is private, set it in Settings → Variables and secrets.")
13
-
14
- repo_dir = snapshot_download(
15
- repo_id=HF_REPO,
16
- repo_type="model",
17
- local_dir=LOCAL_WEIGHTS_DIR,
18
- local_dir_use_symlinks=False,
19
- token=token,
20
- )
21
- print(repo_dir)
22
- os.environ["nnUNet_results"] = repo_dir
23
- os.environ.setdefault("nnUNet_raw", "./nnunet_raw")
24
- os.environ.setdefault("nnUNet_preprocessed", "./nnunet_preprocessed")
25
- os.environ["OPENBLAS_NUM_THREADS"] = "1"
26
-
27
- import streamlit as st
28
- import numpy as np
29
- import SimpleITK as sitk
30
- import io
31
-
32
- from process import SynthradAlgorithm2
33
-
34
- st.set_page_config(page_title="SynthRad (nnUNetv2) Demo", layout="wide")
35
- st.title("SynthRad (MRI/CBCT + Mask → synthetic CT) — Streamlit Demo")
36
-
37
- if "algo" not in st.session_state:
38
- st.session_state.algo = SynthradAlgorithm2()
39
- if "synth_ct" not in st.session_state:
40
- st.session_state.synth_ct = None
41
- if "orig_meta" not in st.session_state:
42
- st.session_state.orig_meta = None
43
- if "vol_np" not in st.session_state:
44
- st.session_state.vol_np = None
45
- if "input_mr" not in st.session_state:
46
- st.session_state.input_mr = None
47
- if "input_mask" not in st.session_state:
48
- st.session_state.input_mask = None
49
-
50
- st.subheader("Input")
51
- src = st.radio("Source", ["Sample", "Upload"], index=0, horizontal=True)
52
-
53
- from huggingface_hub import snapshot_download
54
- import os
55
- import tempfile, zipfile
56
-
57
- def _download_sitk_image(img: sitk.Image, file_name: str, label: str):
58
- with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as tmp:
59
- sitk.WriteImage(img, tmp.name)
60
- tmp_path = tmp.name
61
- with open(tmp_path, "rb") as f:
62
- st.download_button(
63
- label=label,
64
- data=f.read(),
65
- file_name=file_name,
66
- mime="application/octet-stream",
67
- )
68
- try:
69
- os.remove(tmp_path)
70
- except Exception:
71
- pass
72
- SAMPLES_REPO = "Synthard2025KoalAI/synthrad2025_task1"
73
- # samples_dir = os.path.join(repo_dir, "samples")
74
-
75
- SAMPLE_MAP = {
76
- "Abdomen (sample)": {
77
- "region": "Abdomen",
78
- "mri": os.path.join(repo_dir, "Abdomen", "cbct.mha"),
79
- "mask": os.path.join(repo_dir,"Abdomen", "mask2.mha"),
80
- },
81
- "Head and Neck (sample)": {
82
- "region": "Head and Neck",
83
- "mri": os.path.join(repo_dir, "Head and Neck", "cbct.mha"),
84
- "mask": os.path.join(repo_dir, "Head and Neck", "mask2.mha"),
85
- },
86
- "Thorax (sample)": {
87
- "region": "Thorax",
88
- "mri": os.path.join(repo_dir, "Thorax", "cbct.mha"),
89
- "mask": os.path.join(repo_dir, "Thorax", "mask2.mha"),
90
- },
91
- }
92
- c1, c2, c3 = st.columns([2, 2, 1])
93
-
94
- if src == "Upload":
95
- with c1:
96
- mri_file = st.file_uploader("MRI volume (.nii/.nii.gz/.mha)", type=["nii", "nii.gz", "mha"], key="mri")
97
- with c2:
98
- mask_file = st.file_uploader("Mask volume (.nii/.nii.gz/.mha)", type=["nii", "nii.gz", "mha"], key="mask")
99
- with c3:
100
- region = st.radio("Region", ["Head and Neck", "Abdomen", "Thorax"], index=1)
101
- inputs_ready = (mri_file is not None) and (mask_file is not None)
102
- region_for_run = region
103
- else:
104
- with c1:
105
- sample_key = st.selectbox("Choose a sample", list(SAMPLE_MAP.keys()))
106
- with c2:
107
- st.markdown("Region (fixed by sample)")
108
- st.write(f"**{SAMPLE_MAP[sample_key]['region']}**")
109
- with c3:
110
- st.markdown(" ", unsafe_allow_html=True)
111
- inputs_ready = (sample_key is not None)
112
- region_for_run = SAMPLE_MAP[sample_key]["region"]
113
-
114
- run_btn = st.button("Run", type="primary", disabled=not inputs_ready)
115
-
116
- def _read_sitk_from_uploaded(f):
117
- suffix = ".nii.gz" if f.name.endswith(".nii.gz") else os.path.splitext(f.name)[1]
118
- bio = io.BytesIO(f.read())
119
- import tempfile
120
- with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
121
- tmp.write(bio.getvalue()); tmp.flush(); path = tmp.name
122
- img = sitk.ReadImage(path)
123
- try: os.remove(path)
124
- except Exception: pass
125
- return img
126
-
127
- def _read_sitk_from_path(path):
128
- if not os.path.exists(path):
129
- st.error(f"Sample file missing: {path}")
130
- st.stop()
131
- return sitk.ReadImage(path)
132
-
133
- def _norm2u8(slice2d):
134
- s = slice2d.astype(np.float32)
135
- s = (s - np.percentile(s, 1)) / (np.percentile(s, 99) - np.percentile(s, 1) + 1e-6)
136
- s = np.clip(s, 0, 1)
137
- return (s * 255).astype(np.uint8)
138
- if run_btn:
139
- with st.spinner("Running nnUNetv2 SynthradAlgorithm..."):
140
- if src == "Upload":
141
- mr_img = _read_sitk_from_uploaded(mri_file)
142
- mask_img = _read_sitk_from_uploaded(mask_file)
143
- else:
144
- sample = SAMPLE_MAP[sample_key]
145
- mr_img = _read_sitk_from_path(sample["mri"])
146
- mask_img = _read_sitk_from_path(sample["mask"])
147
-
148
- st.session_state.orig_meta = (mr_img.GetSpacing(), mr_img.GetOrigin(), mr_img.GetDirection())
149
- out_img = st.session_state.algo.predict({"image": mr_img, "mask": mask_img, "region": region_for_run})
150
- st.session_state.synth_ct = out_img
151
- st.session_state.vol_np = sitk.GetArrayFromImage(out_img).astype(np.float32)
152
- st.session_state.input_mr = mr_img
153
- st.session_state.input_mask = mask_img
154
-
155
-
156
-
157
- if st.session_state.vol_np is None:
158
- st.info("Please upload MRI + Mask or Sample, then Run")
159
- else:
160
- mr_lps = sitk.DICOMOrient(st.session_state.synth_ct, "LPS")
161
- vol = sitk.GetArrayFromImage(mr_lps).astype(np.float32)
162
- # vol = st.session_state.vol_np # shape: (D, H, W) = (Z, Y, X)
163
- D, H, W = vol.shape
164
-
165
-
166
- # out_path = "synth_ct.nii.gz"
167
- # sitk.WriteImage(st.session_state.synth_ct, out_path)
168
- # with open(out_path, "rb") as f:
169
- # st.download_button(
170
- # label="Download synthetic CT (.nii.gz)",
171
- # data=f.read(),
172
- # file_name="synth_ct.nii.gz",
173
- # mime="application/octet-stream",
174
- # )
175
- col_d1, col_d2, col_d3 = st.columns(3)
176
-
177
- with col_d3:
178
- _download_sitk_image(
179
- st.session_state.synth_ct,
180
- file_name="synth_ct.nii.gz",
181
- label="Download synthetic CT"
182
- )
183
-
184
- with col_d1:
185
- if st.session_state.input_mr is not None:
186
- _download_sitk_image(
187
- st.session_state.input_mr,
188
- file_name="input_mri.nii.gz", # 如果你其实是 CBCT,可改成 input_cbct.nii.gz
189
- label="Download input MRI/CBCT"
190
- )
191
- else:
192
- st.button("Download input MRI/CBCT", disabled=True)
193
-
194
- with col_d2:
195
- if st.session_state.input_mask is not None:
196
- # (可选)确保掩膜是整型:mask = sitk.Cast(st.session_state.input_mask, sitk.sitkUInt16)
197
- _download_sitk_image(
198
- st.session_state.input_mask,
199
- file_name="input_mask.nii.gz",
200
- label="Download input Mask"
201
- )
202
- else:
203
- st.button("Download input Mask", disabled=True)
204
-