FelixzeroSun commited on
Commit
867f0d3
·
1 Parent(s): 2dc96d0
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +19 -49
  3. process_1.py +2 -1
  4. workflow.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -10,25 +10,21 @@ import streamlit as st
10
  from PIL import Image, ImageDraw
11
  from huggingface_hub import snapshot_download
12
 
13
- # =========================
14
- # 配置:两个任务的模型仓库 & 本地路径
15
- # =========================
16
- # 你可以将两个任务分别指向不同的 HF repo;如果都在同一个,也可以都填同一个。
17
  HF_REPOS = {
18
- "Task 1 (MR → CT)": "aehrc/Synthrad2025",
19
- "Task 2 (CBCT → CT)": "aehrc/Synthrad2025", # 如有专门CBCT→CT的repo可在此替换
20
  }
21
  LOCAL_WEIGHTS_DIRS = {
22
  "Task 1 (MR → CT)": os.path.abspath("weights/task1"),
23
  "Task 2 (CBCT → CT)": os.path.abspath("weights/task2"),
24
  }
25
 
26
- # 环境设置(Token)
27
  token = os.getenv("HF_TOKEN")
28
  if token is None:
29
  print("[Warn] HF_TOKEN not set. If the model repo is private, set it in Settings → Variables and secrets.")
30
 
31
- # 先下载两个任务的权重(如需按需下载,可在选择任务后再下载)
32
  REPO_DIRS = {}
33
  for task_name, repo in HF_REPOS.items():
34
  repo_dir = snapshot_download(
@@ -40,13 +36,11 @@ for task_name, repo in HF_REPOS.items():
40
  )
41
  REPO_DIRS[task_name] = repo_dir
42
 
43
- # nnUNet 环境变量(指向“当前任务”的 results,会在用户切换任务时动态更新)
44
  os.environ.setdefault("nnUNet_raw", "./nnunet_raw")
45
  os.environ.setdefault("nnUNet_preprocessed", "./nnunet_preprocessed")
46
  os.environ["OPENBLAS_NUM_THREADS"] = "1"
47
 
48
- # 从 process.py 导入两个任务的算法类
49
- # 确保你在 process.py 中定义了 SynthradAlgorithm1(MR→CT)和 SynthradAlgorithm2(CBCT→CT)
50
  from process import SynthradAlgorithm2
51
 
52
  from process_1 import SynthradAlgorithm1
@@ -56,21 +50,19 @@ from process_1 import SynthradAlgorithm1
56
  # =========================
57
  st.set_page_config(page_title="SynthRad (nnUNetv2) Demo", layout="wide")
58
  st.title("SynthRad — MRI/CBCT + Mask → synthetic CT")
59
-
60
- # 任务选择
61
  TASKS = ["Task 1 (MR → CT)", "Task 2 (CBCT → CT)"]
62
  task = st.radio("Select Task", TASKS, index=0, horizontal=True)
63
 
64
- # 根据任务设置标题/提示
65
  if task == "Task 1 (MR → CT)":
66
  vol_label = "MRI volume (.nii/.nii.gz/.mha)"
67
  else:
68
  vol_label = "CBCT volume (.nii/.nii.gz/.mha)"
69
 
70
- # 动态切换 nnUNet 的 results(不同任务使用不同 results 目录)
71
  os.environ["nnUNet_results"] = REPO_DIRS[task]
72
 
73
- # session_state 初始化
74
  if "algos" not in st.session_state:
75
  st.session_state.algos = {}
76
  if "synth_ct" not in st.session_state:
@@ -84,7 +76,6 @@ if "input_vol" not in st.session_state:
84
  if "input_mask" not in st.session_state:
85
  st.session_state.input_mask = None
86
 
87
- # 懒加载对应任务的算法实例
88
  def get_algo(task_name: str):
89
  if task_name not in st.session_state.algos:
90
  if task_name == "Task 1 (MR → CT)":
@@ -98,25 +89,18 @@ algo = get_algo(task)
98
  st.subheader("Input")
99
  src = st.radio("Source", ["Sample", "Upload"], index=0, horizontal=True)
100
 
101
- # =========================
102
- # 样例映射(两任务可共用同一份样例,也可按需区分)
103
- # 这里假设 repo_dir 下有如下结构:
104
- # repo_dir/Abdomen/{cbct.mha, mask.mha} 或 {mri.mha, mask.mha}
105
- # repo_dir/Head and Neck/{cbct.mha or mri.mha, mask.mha}
106
- # repo_dir/Thorax/{cbct.mha or mri.mha, mask.mha}
107
- # 如果你的文件名不同,请按需调整。
108
- # =========================
109
 
110
  def build_sample_map(task_name: str):
111
  repo_dir = REPO_DIRS[task_name]
112
  if task_name == "Task 1 (MR → CT)":
113
  vol_key = "mri"
114
- vol_fname = "mr.mha" # 如果你的样例文件名不是 mri.mha,请改成实际名称
115
- mask_fname = "mask1.mha" # 如果你的样例文件名不是 mri.mha,请改成实际名称
116
  else:
117
  vol_key = "cbct"
118
- vol_fname = "cbct.mha" # 如果你的样例文件名不是 cbct.mha,请改成实际名称
119
- mask_fname = "mask2.mha" # 如果你的样例文件名不是 mri.mha,请改成实际名称
120
  sample_map = {
121
  "Abdomen (sample)": {
122
  "region": "Abdomen",
@@ -138,9 +122,7 @@ def build_sample_map(task_name: str):
138
 
139
  SAMPLE_MAP = build_sample_map(task)
140
 
141
- # =========================
142
- # 小工具函数
143
- # =========================
144
  def _download_sitk_image(img: sitk.Image, file_name: str, label: str):
145
  with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as tmp:
146
  sitk.WriteImage(img, tmp.name)
@@ -181,9 +163,6 @@ def _norm2u8(slice2d):
181
  s = np.clip(s, 0, 1)
182
  return (s * 255).astype(np.uint8)
183
 
184
- # =========================
185
- # 输入区域(Upload or Sample)
186
- # =========================
187
  c1, c2, c3 = st.columns([2, 2, 1])
188
 
189
  if src == "Upload":
@@ -208,9 +187,6 @@ else:
208
 
209
  run_btn = st.button("Run", type="primary", disabled=not inputs_ready)
210
 
211
- # =========================
212
- # 推理
213
- # =========================
214
  if run_btn:
215
  with st.spinner(f"Running nnUNetv2 {('SynthradAlgorithm1' if task=='Task 1 (MR → CT)' else 'SynthradAlgorithm2')}..."):
216
  if src == "Upload":
@@ -221,15 +197,13 @@ if run_btn:
221
  in_vol_img = _read_sitk_from_path(sample["vol"])
222
  mask_img = _read_sitk_from_path(sample["mask"])
223
 
224
- # 保存原始元信息
225
  st.session_state.orig_meta = (
226
  in_vol_img.GetSpacing(),
227
  in_vol_img.GetOrigin(),
228
  in_vol_img.GetDirection(),
229
  )
230
 
231
- # 调用不同任务的算法
232
- # 约定:算法统一使用 dict 输入:{"image": <sitk.Image>, "mask": <sitk.Image>, "region": <str>}
233
  out_img = algo.predict({"image": in_vol_img, "mask": mask_img, "region": region_for_run})
234
 
235
  st.session_state.synth_ct = out_img
@@ -237,20 +211,17 @@ if run_btn:
237
  st.session_state.input_vol = in_vol_img
238
  st.session_state.input_mask = mask_img
239
 
240
- # =========================
241
- # 结果与下载
242
- # =========================
243
  if st.session_state.vol_np is None:
244
- st.info("请先选择任务与输入(Upload Sample),然后点击 Run")
245
  else:
246
- # 将输出转为 LPS 方向做显示(可选)
247
  out_lps = sitk.DICOMOrient(st.session_state.synth_ct, "LPS")
248
  vol = sitk.GetArrayFromImage(out_lps).astype(np.float32)
249
  D, H, W = vol.shape
250
 
251
  col_d1, col_d2, col_d3 = st.columns(3)
252
 
253
- # 下载合成CT
254
  with col_d3:
255
  _download_sitk_image(
256
  st.session_state.synth_ct,
@@ -258,7 +229,7 @@ else:
258
  label="Download synthetic CT",
259
  )
260
 
261
- # 下载输入体积(根据任务区分命名)
262
  with col_d1:
263
  if st.session_state.input_vol is not None:
264
  in_name = "input_mr.nii.gz" if task == "Task 1 (MR → CT)" else "input_cbct.nii.gz"
@@ -271,7 +242,6 @@ else:
271
  else:
272
  st.button("Download input", disabled=True)
273
 
274
- # 下载掩膜
275
  with col_d2:
276
  if st.session_state.input_mask is not None:
277
  _download_sitk_image(
 
10
  from PIL import Image, ImageDraw
11
  from huggingface_hub import snapshot_download
12
 
13
+
 
 
 
14
  HF_REPOS = {
15
+ "Task 1 (MR → CT)": "aehrc/Synthrad2025_task1",
16
+ "Task 2 (CBCT → CT)": "aehrc/Synthrad2025_task2",
17
  }
18
  LOCAL_WEIGHTS_DIRS = {
19
  "Task 1 (MR → CT)": os.path.abspath("weights/task1"),
20
  "Task 2 (CBCT → CT)": os.path.abspath("weights/task2"),
21
  }
22
 
23
+
24
  token = os.getenv("HF_TOKEN")
25
  if token is None:
26
  print("[Warn] HF_TOKEN not set. If the model repo is private, set it in Settings → Variables and secrets.")
27
 
 
28
  REPO_DIRS = {}
29
  for task_name, repo in HF_REPOS.items():
30
  repo_dir = snapshot_download(
 
36
  )
37
  REPO_DIRS[task_name] = repo_dir
38
 
39
+
40
  os.environ.setdefault("nnUNet_raw", "./nnunet_raw")
41
  os.environ.setdefault("nnUNet_preprocessed", "./nnunet_preprocessed")
42
  os.environ["OPENBLAS_NUM_THREADS"] = "1"
43
 
 
 
44
  from process import SynthradAlgorithm2
45
 
46
  from process_1 import SynthradAlgorithm1
 
50
  # =========================
51
  st.set_page_config(page_title="SynthRad (nnUNetv2) Demo", layout="wide")
52
  st.title("SynthRad — MRI/CBCT + Mask → synthetic CT")
53
+ st.image("/home/head_neck/sub_2/Synthrad2025/workflow.png",width=800)
 
54
  TASKS = ["Task 1 (MR → CT)", "Task 2 (CBCT → CT)"]
55
  task = st.radio("Select Task", TASKS, index=0, horizontal=True)
56
 
57
+
58
  if task == "Task 1 (MR → CT)":
59
  vol_label = "MRI volume (.nii/.nii.gz/.mha)"
60
  else:
61
  vol_label = "CBCT volume (.nii/.nii.gz/.mha)"
62
 
63
+
64
  os.environ["nnUNet_results"] = REPO_DIRS[task]
65
 
 
66
  if "algos" not in st.session_state:
67
  st.session_state.algos = {}
68
  if "synth_ct" not in st.session_state:
 
76
  if "input_mask" not in st.session_state:
77
  st.session_state.input_mask = None
78
 
 
79
  def get_algo(task_name: str):
80
  if task_name not in st.session_state.algos:
81
  if task_name == "Task 1 (MR → CT)":
 
89
  st.subheader("Input")
90
  src = st.radio("Source", ["Sample", "Upload"], index=0, horizontal=True)
91
 
92
+
 
 
 
 
 
 
 
93
 
94
  def build_sample_map(task_name: str):
95
  repo_dir = REPO_DIRS[task_name]
96
  if task_name == "Task 1 (MR → CT)":
97
  vol_key = "mri"
98
+ vol_fname = "mr.mha"
99
+ mask_fname = "mask1.mha"
100
  else:
101
  vol_key = "cbct"
102
+ vol_fname = "cbct.mha"
103
+ mask_fname = "mask2.mha"
104
  sample_map = {
105
  "Abdomen (sample)": {
106
  "region": "Abdomen",
 
122
 
123
  SAMPLE_MAP = build_sample_map(task)
124
 
125
+
 
 
126
  def _download_sitk_image(img: sitk.Image, file_name: str, label: str):
127
  with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as tmp:
128
  sitk.WriteImage(img, tmp.name)
 
163
  s = np.clip(s, 0, 1)
164
  return (s * 255).astype(np.uint8)
165
 
 
 
 
166
  c1, c2, c3 = st.columns([2, 2, 1])
167
 
168
  if src == "Upload":
 
187
 
188
  run_btn = st.button("Run", type="primary", disabled=not inputs_ready)
189
 
 
 
 
190
  if run_btn:
191
  with st.spinner(f"Running nnUNetv2 {('SynthradAlgorithm1' if task=='Task 1 (MR → CT)' else 'SynthradAlgorithm2')}..."):
192
  if src == "Upload":
 
197
  in_vol_img = _read_sitk_from_path(sample["vol"])
198
  mask_img = _read_sitk_from_path(sample["mask"])
199
 
200
+
201
  st.session_state.orig_meta = (
202
  in_vol_img.GetSpacing(),
203
  in_vol_img.GetOrigin(),
204
  in_vol_img.GetDirection(),
205
  )
206
 
 
 
207
  out_img = algo.predict({"image": in_vol_img, "mask": mask_img, "region": region_for_run})
208
 
209
  st.session_state.synth_ct = out_img
 
211
  st.session_state.input_vol = in_vol_img
212
  st.session_state.input_mask = mask_img
213
 
 
 
 
214
  if st.session_state.vol_np is None:
215
+ st.info("Select Upload or Sample, then click Run")
216
  else:
217
+
218
  out_lps = sitk.DICOMOrient(st.session_state.synth_ct, "LPS")
219
  vol = sitk.GetArrayFromImage(out_lps).astype(np.float32)
220
  D, H, W = vol.shape
221
 
222
  col_d1, col_d2, col_d3 = st.columns(3)
223
 
224
+
225
  with col_d3:
226
  _download_sitk_image(
227
  st.session_state.synth_ct,
 
229
  label="Download synthetic CT",
230
  )
231
 
232
+
233
  with col_d1:
234
  if st.session_state.input_vol is not None:
235
  in_name = "input_mr.nii.gz" if task == "Task 1 (MR → CT)" else "input_cbct.nii.gz"
 
242
  else:
243
  st.button("Download input", disabled=True)
244
 
 
245
  with col_d2:
246
  if st.session_state.input_mask is not None:
247
  _download_sitk_image(
process_1.py CHANGED
@@ -24,10 +24,11 @@ import shutil
24
  import os
25
 
26
  os.environ["OPENBLAS_NUM_THREADS"] = "1"
 
27
  device = torch.device("cuda:0" if torch.cuda.is_available() and not force_cpu else "cpu")
28
 
29
 
30
- force_cpu = os.getenv("FORCE_CPU", "0") == "1"
31
  class SynthradAlgorithm1(BaseSynthradAlgorithm):
32
  """
33
  This class implements a simple synthetic CT generation algorithm that segments all values greater than 2 in the input image.
 
24
  import os
25
 
26
  os.environ["OPENBLAS_NUM_THREADS"] = "1"
27
+ force_cpu = os.getenv("FORCE_CPU", "0") == "1"
28
  device = torch.device("cuda:0" if torch.cuda.is_available() and not force_cpu else "cpu")
29
 
30
 
31
+
32
  class SynthradAlgorithm1(BaseSynthradAlgorithm):
33
  """
34
  This class implements a simple synthetic CT generation algorithm that segments all values greater than 2 in the input image.
workflow.png ADDED

Git LFS Details

  • SHA256: f8475c69f68d15781374831a484b6648dffed781530693905a9964855f98cdbc
  • Pointer size: 131 Bytes
  • Size of remote file: 194 kB