Spaces:
Sleeping
Sleeping
FelixzeroSun
commited on
Commit
·
867f0d3
1
Parent(s):
2dc96d0
debug
Browse files- .gitattributes +1 -0
- app.py +19 -49
- process_1.py +2 -1
- workflow.png +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
|
@@ -10,25 +10,21 @@ import streamlit as st
|
|
| 10 |
from PIL import Image, ImageDraw
|
| 11 |
from huggingface_hub import snapshot_download
|
| 12 |
|
| 13 |
-
|
| 14 |
-
# 配置:两个任务的模型仓库 & 本地路径
|
| 15 |
-
# =========================
|
| 16 |
-
# 你可以将两个任务分别指向不同的 HF repo;如果都在同一个,也可以都填同一个。
|
| 17 |
HF_REPOS = {
|
| 18 |
-
"Task 1 (MR → CT)": "aehrc/
|
| 19 |
-
"Task 2 (CBCT → CT)": "aehrc/
|
| 20 |
}
|
| 21 |
LOCAL_WEIGHTS_DIRS = {
|
| 22 |
"Task 1 (MR → CT)": os.path.abspath("weights/task1"),
|
| 23 |
"Task 2 (CBCT → CT)": os.path.abspath("weights/task2"),
|
| 24 |
}
|
| 25 |
|
| 26 |
-
|
| 27 |
token = os.getenv("HF_TOKEN")
|
| 28 |
if token is None:
|
| 29 |
print("[Warn] HF_TOKEN not set. If the model repo is private, set it in Settings → Variables and secrets.")
|
| 30 |
|
| 31 |
-
# 先下载两个任务的权重(如需按需下载,可在选择任务后再下载)
|
| 32 |
REPO_DIRS = {}
|
| 33 |
for task_name, repo in HF_REPOS.items():
|
| 34 |
repo_dir = snapshot_download(
|
|
@@ -40,13 +36,11 @@ for task_name, repo in HF_REPOS.items():
|
|
| 40 |
)
|
| 41 |
REPO_DIRS[task_name] = repo_dir
|
| 42 |
|
| 43 |
-
|
| 44 |
os.environ.setdefault("nnUNet_raw", "./nnunet_raw")
|
| 45 |
os.environ.setdefault("nnUNet_preprocessed", "./nnunet_preprocessed")
|
| 46 |
os.environ["OPENBLAS_NUM_THREADS"] = "1"
|
| 47 |
|
| 48 |
-
# 从 process.py 导入两个任务的算法类
|
| 49 |
-
# 确保你在 process.py 中定义了 SynthradAlgorithm1(MR→CT)和 SynthradAlgorithm2(CBCT→CT)
|
| 50 |
from process import SynthradAlgorithm2
|
| 51 |
|
| 52 |
from process_1 import SynthradAlgorithm1
|
|
@@ -56,21 +50,19 @@ from process_1 import SynthradAlgorithm1
|
|
| 56 |
# =========================
|
| 57 |
st.set_page_config(page_title="SynthRad (nnUNetv2) Demo", layout="wide")
|
| 58 |
st.title("SynthRad — MRI/CBCT + Mask → synthetic CT")
|
| 59 |
-
|
| 60 |
-
# 任务选择
|
| 61 |
TASKS = ["Task 1 (MR → CT)", "Task 2 (CBCT → CT)"]
|
| 62 |
task = st.radio("Select Task", TASKS, index=0, horizontal=True)
|
| 63 |
|
| 64 |
-
|
| 65 |
if task == "Task 1 (MR → CT)":
|
| 66 |
vol_label = "MRI volume (.nii/.nii.gz/.mha)"
|
| 67 |
else:
|
| 68 |
vol_label = "CBCT volume (.nii/.nii.gz/.mha)"
|
| 69 |
|
| 70 |
-
|
| 71 |
os.environ["nnUNet_results"] = REPO_DIRS[task]
|
| 72 |
|
| 73 |
-
# session_state 初始化
|
| 74 |
if "algos" not in st.session_state:
|
| 75 |
st.session_state.algos = {}
|
| 76 |
if "synth_ct" not in st.session_state:
|
|
@@ -84,7 +76,6 @@ if "input_vol" not in st.session_state:
|
|
| 84 |
if "input_mask" not in st.session_state:
|
| 85 |
st.session_state.input_mask = None
|
| 86 |
|
| 87 |
-
# 懒加载对应任务的算法实例
|
| 88 |
def get_algo(task_name: str):
|
| 89 |
if task_name not in st.session_state.algos:
|
| 90 |
if task_name == "Task 1 (MR → CT)":
|
|
@@ -98,25 +89,18 @@ algo = get_algo(task)
|
|
| 98 |
st.subheader("Input")
|
| 99 |
src = st.radio("Source", ["Sample", "Upload"], index=0, horizontal=True)
|
| 100 |
|
| 101 |
-
|
| 102 |
-
# 样例映射(两任务可共用同一份样例,也可按需区分)
|
| 103 |
-
# 这里假设 repo_dir 下有如下结构:
|
| 104 |
-
# repo_dir/Abdomen/{cbct.mha, mask.mha} 或 {mri.mha, mask.mha}
|
| 105 |
-
# repo_dir/Head and Neck/{cbct.mha or mri.mha, mask.mha}
|
| 106 |
-
# repo_dir/Thorax/{cbct.mha or mri.mha, mask.mha}
|
| 107 |
-
# 如果你的文件名不同,请按需调整。
|
| 108 |
-
# =========================
|
| 109 |
|
| 110 |
def build_sample_map(task_name: str):
|
| 111 |
repo_dir = REPO_DIRS[task_name]
|
| 112 |
if task_name == "Task 1 (MR → CT)":
|
| 113 |
vol_key = "mri"
|
| 114 |
-
vol_fname = "mr.mha"
|
| 115 |
-
mask_fname = "mask1.mha"
|
| 116 |
else:
|
| 117 |
vol_key = "cbct"
|
| 118 |
-
vol_fname = "cbct.mha"
|
| 119 |
-
mask_fname = "mask2.mha"
|
| 120 |
sample_map = {
|
| 121 |
"Abdomen (sample)": {
|
| 122 |
"region": "Abdomen",
|
|
@@ -138,9 +122,7 @@ def build_sample_map(task_name: str):
|
|
| 138 |
|
| 139 |
SAMPLE_MAP = build_sample_map(task)
|
| 140 |
|
| 141 |
-
|
| 142 |
-
# 小工具函数
|
| 143 |
-
# =========================
|
| 144 |
def _download_sitk_image(img: sitk.Image, file_name: str, label: str):
|
| 145 |
with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as tmp:
|
| 146 |
sitk.WriteImage(img, tmp.name)
|
|
@@ -181,9 +163,6 @@ def _norm2u8(slice2d):
|
|
| 181 |
s = np.clip(s, 0, 1)
|
| 182 |
return (s * 255).astype(np.uint8)
|
| 183 |
|
| 184 |
-
# =========================
|
| 185 |
-
# 输入区域(Upload or Sample)
|
| 186 |
-
# =========================
|
| 187 |
c1, c2, c3 = st.columns([2, 2, 1])
|
| 188 |
|
| 189 |
if src == "Upload":
|
|
@@ -208,9 +187,6 @@ else:
|
|
| 208 |
|
| 209 |
run_btn = st.button("Run", type="primary", disabled=not inputs_ready)
|
| 210 |
|
| 211 |
-
# =========================
|
| 212 |
-
# 推理
|
| 213 |
-
# =========================
|
| 214 |
if run_btn:
|
| 215 |
with st.spinner(f"Running nnUNetv2 {('SynthradAlgorithm1' if task=='Task 1 (MR → CT)' else 'SynthradAlgorithm2')}..."):
|
| 216 |
if src == "Upload":
|
|
@@ -221,15 +197,13 @@ if run_btn:
|
|
| 221 |
in_vol_img = _read_sitk_from_path(sample["vol"])
|
| 222 |
mask_img = _read_sitk_from_path(sample["mask"])
|
| 223 |
|
| 224 |
-
|
| 225 |
st.session_state.orig_meta = (
|
| 226 |
in_vol_img.GetSpacing(),
|
| 227 |
in_vol_img.GetOrigin(),
|
| 228 |
in_vol_img.GetDirection(),
|
| 229 |
)
|
| 230 |
|
| 231 |
-
# 调用不同任务的算法
|
| 232 |
-
# 约定:算法统一使用 dict 输入:{"image": <sitk.Image>, "mask": <sitk.Image>, "region": <str>}
|
| 233 |
out_img = algo.predict({"image": in_vol_img, "mask": mask_img, "region": region_for_run})
|
| 234 |
|
| 235 |
st.session_state.synth_ct = out_img
|
|
@@ -237,20 +211,17 @@ if run_btn:
|
|
| 237 |
st.session_state.input_vol = in_vol_img
|
| 238 |
st.session_state.input_mask = mask_img
|
| 239 |
|
| 240 |
-
# =========================
|
| 241 |
-
# 结果与下载
|
| 242 |
-
# =========================
|
| 243 |
if st.session_state.vol_np is None:
|
| 244 |
-
st.info("
|
| 245 |
else:
|
| 246 |
-
|
| 247 |
out_lps = sitk.DICOMOrient(st.session_state.synth_ct, "LPS")
|
| 248 |
vol = sitk.GetArrayFromImage(out_lps).astype(np.float32)
|
| 249 |
D, H, W = vol.shape
|
| 250 |
|
| 251 |
col_d1, col_d2, col_d3 = st.columns(3)
|
| 252 |
|
| 253 |
-
|
| 254 |
with col_d3:
|
| 255 |
_download_sitk_image(
|
| 256 |
st.session_state.synth_ct,
|
|
@@ -258,7 +229,7 @@ else:
|
|
| 258 |
label="Download synthetic CT",
|
| 259 |
)
|
| 260 |
|
| 261 |
-
|
| 262 |
with col_d1:
|
| 263 |
if st.session_state.input_vol is not None:
|
| 264 |
in_name = "input_mr.nii.gz" if task == "Task 1 (MR → CT)" else "input_cbct.nii.gz"
|
|
@@ -271,7 +242,6 @@ else:
|
|
| 271 |
else:
|
| 272 |
st.button("Download input", disabled=True)
|
| 273 |
|
| 274 |
-
# 下载掩膜
|
| 275 |
with col_d2:
|
| 276 |
if st.session_state.input_mask is not None:
|
| 277 |
_download_sitk_image(
|
|
|
|
| 10 |
from PIL import Image, ImageDraw
|
| 11 |
from huggingface_hub import snapshot_download
|
| 12 |
|
| 13 |
+
|
|
|
|
|
|
|
|
|
|
| 14 |
HF_REPOS = {
|
| 15 |
+
"Task 1 (MR → CT)": "aehrc/Synthrad2025_task1",
|
| 16 |
+
"Task 2 (CBCT → CT)": "aehrc/Synthrad2025_task2",
|
| 17 |
}
|
| 18 |
LOCAL_WEIGHTS_DIRS = {
|
| 19 |
"Task 1 (MR → CT)": os.path.abspath("weights/task1"),
|
| 20 |
"Task 2 (CBCT → CT)": os.path.abspath("weights/task2"),
|
| 21 |
}
|
| 22 |
|
| 23 |
+
|
| 24 |
token = os.getenv("HF_TOKEN")
|
| 25 |
if token is None:
|
| 26 |
print("[Warn] HF_TOKEN not set. If the model repo is private, set it in Settings → Variables and secrets.")
|
| 27 |
|
|
|
|
| 28 |
REPO_DIRS = {}
|
| 29 |
for task_name, repo in HF_REPOS.items():
|
| 30 |
repo_dir = snapshot_download(
|
|
|
|
| 36 |
)
|
| 37 |
REPO_DIRS[task_name] = repo_dir
|
| 38 |
|
| 39 |
+
|
| 40 |
os.environ.setdefault("nnUNet_raw", "./nnunet_raw")
|
| 41 |
os.environ.setdefault("nnUNet_preprocessed", "./nnunet_preprocessed")
|
| 42 |
os.environ["OPENBLAS_NUM_THREADS"] = "1"
|
| 43 |
|
|
|
|
|
|
|
| 44 |
from process import SynthradAlgorithm2
|
| 45 |
|
| 46 |
from process_1 import SynthradAlgorithm1
|
|
|
|
| 50 |
# =========================
|
| 51 |
st.set_page_config(page_title="SynthRad (nnUNetv2) Demo", layout="wide")
|
| 52 |
st.title("SynthRad — MRI/CBCT + Mask → synthetic CT")
|
| 53 |
+
st.image("/home/head_neck/sub_2/Synthrad2025/workflow.png",width=800)
|
|
|
|
| 54 |
TASKS = ["Task 1 (MR → CT)", "Task 2 (CBCT → CT)"]
|
| 55 |
task = st.radio("Select Task", TASKS, index=0, horizontal=True)
|
| 56 |
|
| 57 |
+
|
| 58 |
if task == "Task 1 (MR → CT)":
|
| 59 |
vol_label = "MRI volume (.nii/.nii.gz/.mha)"
|
| 60 |
else:
|
| 61 |
vol_label = "CBCT volume (.nii/.nii.gz/.mha)"
|
| 62 |
|
| 63 |
+
|
| 64 |
os.environ["nnUNet_results"] = REPO_DIRS[task]
|
| 65 |
|
|
|
|
| 66 |
if "algos" not in st.session_state:
|
| 67 |
st.session_state.algos = {}
|
| 68 |
if "synth_ct" not in st.session_state:
|
|
|
|
| 76 |
if "input_mask" not in st.session_state:
|
| 77 |
st.session_state.input_mask = None
|
| 78 |
|
|
|
|
| 79 |
def get_algo(task_name: str):
|
| 80 |
if task_name not in st.session_state.algos:
|
| 81 |
if task_name == "Task 1 (MR → CT)":
|
|
|
|
| 89 |
st.subheader("Input")
|
| 90 |
src = st.radio("Source", ["Sample", "Upload"], index=0, horizontal=True)
|
| 91 |
|
| 92 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
def build_sample_map(task_name: str):
|
| 95 |
repo_dir = REPO_DIRS[task_name]
|
| 96 |
if task_name == "Task 1 (MR → CT)":
|
| 97 |
vol_key = "mri"
|
| 98 |
+
vol_fname = "mr.mha"
|
| 99 |
+
mask_fname = "mask1.mha"
|
| 100 |
else:
|
| 101 |
vol_key = "cbct"
|
| 102 |
+
vol_fname = "cbct.mha"
|
| 103 |
+
mask_fname = "mask2.mha"
|
| 104 |
sample_map = {
|
| 105 |
"Abdomen (sample)": {
|
| 106 |
"region": "Abdomen",
|
|
|
|
| 122 |
|
| 123 |
SAMPLE_MAP = build_sample_map(task)
|
| 124 |
|
| 125 |
+
|
|
|
|
|
|
|
| 126 |
def _download_sitk_image(img: sitk.Image, file_name: str, label: str):
|
| 127 |
with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as tmp:
|
| 128 |
sitk.WriteImage(img, tmp.name)
|
|
|
|
| 163 |
s = np.clip(s, 0, 1)
|
| 164 |
return (s * 255).astype(np.uint8)
|
| 165 |
|
|
|
|
|
|
|
|
|
|
| 166 |
c1, c2, c3 = st.columns([2, 2, 1])
|
| 167 |
|
| 168 |
if src == "Upload":
|
|
|
|
| 187 |
|
| 188 |
run_btn = st.button("Run", type="primary", disabled=not inputs_ready)
|
| 189 |
|
|
|
|
|
|
|
|
|
|
| 190 |
if run_btn:
|
| 191 |
with st.spinner(f"Running nnUNetv2 {('SynthradAlgorithm1' if task=='Task 1 (MR → CT)' else 'SynthradAlgorithm2')}..."):
|
| 192 |
if src == "Upload":
|
|
|
|
| 197 |
in_vol_img = _read_sitk_from_path(sample["vol"])
|
| 198 |
mask_img = _read_sitk_from_path(sample["mask"])
|
| 199 |
|
| 200 |
+
|
| 201 |
st.session_state.orig_meta = (
|
| 202 |
in_vol_img.GetSpacing(),
|
| 203 |
in_vol_img.GetOrigin(),
|
| 204 |
in_vol_img.GetDirection(),
|
| 205 |
)
|
| 206 |
|
|
|
|
|
|
|
| 207 |
out_img = algo.predict({"image": in_vol_img, "mask": mask_img, "region": region_for_run})
|
| 208 |
|
| 209 |
st.session_state.synth_ct = out_img
|
|
|
|
| 211 |
st.session_state.input_vol = in_vol_img
|
| 212 |
st.session_state.input_mask = mask_img
|
| 213 |
|
|
|
|
|
|
|
|
|
|
| 214 |
if st.session_state.vol_np is None:
|
| 215 |
+
st.info("Select Upload or Sample, then click Run")
|
| 216 |
else:
|
| 217 |
+
|
| 218 |
out_lps = sitk.DICOMOrient(st.session_state.synth_ct, "LPS")
|
| 219 |
vol = sitk.GetArrayFromImage(out_lps).astype(np.float32)
|
| 220 |
D, H, W = vol.shape
|
| 221 |
|
| 222 |
col_d1, col_d2, col_d3 = st.columns(3)
|
| 223 |
|
| 224 |
+
|
| 225 |
with col_d3:
|
| 226 |
_download_sitk_image(
|
| 227 |
st.session_state.synth_ct,
|
|
|
|
| 229 |
label="Download synthetic CT",
|
| 230 |
)
|
| 231 |
|
| 232 |
+
|
| 233 |
with col_d1:
|
| 234 |
if st.session_state.input_vol is not None:
|
| 235 |
in_name = "input_mr.nii.gz" if task == "Task 1 (MR → CT)" else "input_cbct.nii.gz"
|
|
|
|
| 242 |
else:
|
| 243 |
st.button("Download input", disabled=True)
|
| 244 |
|
|
|
|
| 245 |
with col_d2:
|
| 246 |
if st.session_state.input_mask is not None:
|
| 247 |
_download_sitk_image(
|
process_1.py
CHANGED
|
@@ -24,10 +24,11 @@ import shutil
|
|
| 24 |
import os
|
| 25 |
|
| 26 |
os.environ["OPENBLAS_NUM_THREADS"] = "1"
|
|
|
|
| 27 |
device = torch.device("cuda:0" if torch.cuda.is_available() and not force_cpu else "cpu")
|
| 28 |
|
| 29 |
|
| 30 |
-
|
| 31 |
class SynthradAlgorithm1(BaseSynthradAlgorithm):
|
| 32 |
"""
|
| 33 |
This class implements a simple synthetic CT generation algorithm that segments all values greater than 2 in the input image.
|
|
|
|
| 24 |
import os
|
| 25 |
|
| 26 |
os.environ["OPENBLAS_NUM_THREADS"] = "1"
|
| 27 |
+
force_cpu = os.getenv("FORCE_CPU", "0") == "1"
|
| 28 |
device = torch.device("cuda:0" if torch.cuda.is_available() and not force_cpu else "cpu")
|
| 29 |
|
| 30 |
|
| 31 |
+
|
| 32 |
class SynthradAlgorithm1(BaseSynthradAlgorithm):
|
| 33 |
"""
|
| 34 |
This class implements a simple synthetic CT generation algorithm that segments all values greater than 2 in the input image.
|
workflow.png
ADDED
|
Git LFS Details
|