# app.py
"""Streamlit demo: MRI/CBCT volume + mask -> synthetic CT (SynthRad, nnUNetv2).

Weights and sample cases for both tasks are fetched from the Hugging Face Hub;
inference is delegated to ``SynthradAlgorithm1`` (task 1, from ``process_1``)
and ``SynthradAlgorithm2`` (task 2, from ``process``).
"""
import os

# OpenBLAS reads its thread count once, when the library is first loaded (the
# first ``import numpy``), so this must run before the numeric imports below.
# NOTE(review): if the host process already imported numpy before executing
# this script, the setting cannot take effect for it.
os.environ["OPENBLAS_NUM_THREADS"] = "1"

import io
import tempfile
import zipfile

import numpy as np
import plotly.graph_objects as go
import SimpleITK as sitk
import streamlit as st
from PIL import Image, ImageDraw
from huggingface_hub import snapshot_download
from plotly.subplots import make_subplots

TASK1 = "Task 1 (MR → CT)"
TASK2 = "Task 2 (CBCT → CT)"

HF_REPOS = {
    TASK1: "aehrc/Synthrad2025_task1",
    TASK2: "aehrc/Synthrad2025_task2",
}
LOCAL_WEIGHTS_DIRS = {
    TASK1: os.path.abspath("weights/task1"),
    TASK2: os.path.abspath("weights/task2"),
}

st.set_page_config(page_title="SynthRad (nnUNetv2) Demo", layout="wide")


@st.cache_resource(show_spinner=False)
def _download_weights() -> dict:
    """Fetch every task's Hub repo once per server process.

    Streamlit re-executes this script on every widget interaction; caching the
    result keeps ``snapshot_download`` from contacting the Hub on each rerun.

    Returns:
        Mapping from task label to the local directory holding that task's repo.
    """
    token = os.getenv("HF_TOKEN")
    if token is None:
        print("[Warn] HF_TOKEN not set. If the model repo is private, set it in Settings → Variables and secrets.")
    repo_dirs = {}
    for task_name, repo in HF_REPOS.items():
        repo_dirs[task_name] = snapshot_download(
            repo_id=repo,
            repo_type="model",
            local_dir=LOCAL_WEIGHTS_DIRS[task_name],
            local_dir_use_symlinks=False,
            token=token,
        )
    return repo_dirs


REPO_DIRS = _download_weights()

os.environ.setdefault("nnUNet_raw", "./nnunet_raw")
os.environ.setdefault("nnUNet_preprocessed", "./nnunet_preprocessed")

# Imported only after the nnU-Net path variables are set — presumably because
# these modules (or nnunetv2 itself) read them at import time; verify.
from process import SynthradAlgorithm2
from process_1 import SynthradAlgorithm1

st.title("SynthRad — MRI/CBCT + Mask → synthetic CT")
st.image("./workflow.png", width=800)

TASKS = [TASK1, TASK2]
task = st.radio("Select Task", TASKS, index=0, horizontal=True)
vol_label = "MRI volume (.nii/.nii.gz/.mha)" if task == TASK1 else "CBCT volume (.nii/.nii.gz/.mha)"

os.environ["nnUNet_results"] = REPO_DIRS[task]

# Per-session state that must survive Streamlit reruns.
_SESSION_DEFAULTS = {
    "algos": {},        # task label -> cached inference wrapper
    "synth_ct": None,   # last synthetic CT (sitk.Image)
    "orig_meta": None,  # (spacing, origin, direction) of the last input volume
    "vol_np": None,     # float32 array of the last synthetic CT; non-None == "have a result"
    "input_vol": None,  # last input volume (sitk.Image)
    "input_mask": None, # last input mask (sitk.Image)
    "gt_path": None,    # GT CT path of the sample that produced synth_ct (None for uploads)
    "run_task": None,   # task label that produced synth_ct
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default


def get_algo(task_name: str):
    """Return this session's inference wrapper for *task_name*, creating it on first use."""
    algos = st.session_state.algos
    if task_name not in algos:
        algos[task_name] = SynthradAlgorithm1() if task_name == TASK1 else SynthradAlgorithm2()
    return algos[task_name]


algo = get_algo(task)

st.subheader("Input")
src = st.radio("Source", ["Sample", "Upload"], index=0, horizontal=True)

_SAMPLE_REGIONS = ("Abdomen", "Head and Neck", "Thorax")


def build_sample_map(task_name: str):
    """Map each sample label to its region and the volume/mask/GT file paths in the task repo.

    Returns:
        ``{"<Region> (sample)": {"region", "vol", "mask", "gt"}}``, one entry per region.
    """
    repo_dir = REPO_DIRS[task_name]
    if task_name == TASK1:
        vol_fname, mask_fname = "mr.mha", "mask1.mha"
    else:
        vol_fname, mask_fname = "cbct.mha", "mask2.mha"

    def pack(region_dir):
        base = os.path.join(repo_dir, region_dir)
        return {
            "vol": os.path.join(base, vol_fname),
            "mask": os.path.join(base, mask_fname),
            "gt": os.path.join(base, "ct.mha"),  # convention: ground truth is ct.mha
        }

    return {f"{region} (sample)": {"region": region, **pack(region)} for region in _SAMPLE_REGIONS}


SAMPLE_MAP = build_sample_map(task)


def _remove_quietly(path):
    """Best-effort delete of a temporary file; failure only leaves a stray temp file."""
    try:
        os.remove(path)
    except OSError:
        pass


def _download_sitk_image(img: sitk.Image, file_name: str, label: str):
    """Render a download button that serves *img* encoded as compressed NIfTI."""
    # SimpleITK writes to a path, not a buffer, so round-trip through a temp file.
    with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        sitk.WriteImage(img, tmp_path)
        with open(tmp_path, "rb") as fh:
            payload = fh.read()
    finally:
        _remove_quietly(tmp_path)
    st.download_button(
        label=label,
        data=payload,
        file_name=file_name,
        mime="application/octet-stream",
    )


def _read_sitk_from_uploaded(f):
    """Decode a Streamlit-uploaded file into a SimpleITK image.

    The bytes are spilled to a temp file whose suffix mirrors the upload's
    extension, since SimpleITK picks its reader from the file name.
    """
    name = f.name.lower()
    suffix = ".nii.gz" if name.endswith(".nii.gz") else os.path.splitext(name)[1]
    data = f.getvalue()  # whole buffer, independent of the current read position
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(data)
        path = tmp.name
    try:
        return sitk.ReadImage(path)
    finally:
        _remove_quietly(path)


def _read_sitk_from_path(path):
    """Read a bundled sample image; show an error and halt the script run if it is missing."""
    if not os.path.exists(path):
        st.error(f"Sample file missing: {path}")
        st.stop()
    return sitk.ReadImage(path)


def _norm2u8(slice2d):
    """Window a 2-D slice to its 1st-99th percentile range and scale it to uint8 [0, 255]."""
    s = slice2d.astype(np.float32)
    lo, hi = np.percentile(s, [1, 99])
    s = np.clip((s - lo) / (hi - lo + 1e-6), 0, 1)  # epsilon guards constant slices
    return (s * 255).astype(np.uint8)


def _resample_onto(reference, moving, interpolator):
    """Resample *moving* onto *reference*'s grid (identity transform, same pixel type)."""
    res = sitk.ResampleImageFilter()
    res.SetReferenceImage(reference)
    res.SetInterpolator(interpolator)
    return res.Execute(moving)


c1, c2, c3 = st.columns([2, 2, 1])
if src == "Upload":
    with c1:
        vol_file = st.file_uploader(vol_label, type=["nii", "nii.gz", "mha"], key="vol")
    with c2:
        mask_file = st.file_uploader("Mask volume (.nii/.nii.gz/.mha)", type=["nii", "nii.gz", "mha"], key="mask")
    with c3:
        region = st.radio("Region", ["Head and Neck", "Abdomen", "Thorax"], index=1)
    inputs_ready = (vol_file is not None) and (mask_file is not None)
    region_for_run = region
else:
    with c1:
        sample_key = st.selectbox("Choose a sample", list(SAMPLE_MAP.keys()))
    with c2:
        st.markdown("Region (fixed by sample)")
        st.write(f"**{SAMPLE_MAP[sample_key]['region']}**")
    with c3:
        st.markdown(" ", unsafe_allow_html=True)
    inputs_ready = sample_key is not None
    region_for_run = SAMPLE_MAP[sample_key]["region"]

run_btn = st.button("Run", type="primary", disabled=not inputs_ready)

if run_btn:
    algo_name = "SynthradAlgorithm1" if task == TASK1 else "SynthradAlgorithm2"
    with st.spinner(f"Running nnUNetv2 {algo_name}..."):
        if src == "Upload":
            in_vol_img = _read_sitk_from_uploaded(vol_file)
            mask_img = _read_sitk_from_uploaded(mask_file)
            run_gt_path = None
        else:
            sample = SAMPLE_MAP[sample_key]
            in_vol_img = _read_sitk_from_path(sample["vol"])
            mask_img = _read_sitk_from_path(sample["mask"])
            run_gt_path = sample.get("gt")

        out_img = algo.predict({"image": in_vol_img, "mask": mask_img, "region": region_for_run})

        # Update session state only after a successful prediction so every
        # field describes the same run.
        st.session_state.orig_meta = (
            in_vol_img.GetSpacing(),
            in_vol_img.GetOrigin(),
            in_vol_img.GetDirection(),
        )
        st.session_state.synth_ct = out_img
        st.session_state.vol_np = sitk.GetArrayFromImage(out_img).astype(np.float32)
        st.session_state.input_vol = in_vol_img
        st.session_state.input_mask = mask_img
        st.session_state.gt_path = run_gt_path
        st.session_state.run_task = task

if st.session_state.vol_np is None:
    st.info("Select Upload or Sample, then click Run")
else:
    # Put every volume on the input's grid in LPS orientation so one array
    # index addresses the same voxel in all panels.
    in_lps = sitk.DICOMOrient(st.session_state.input_vol, "LPS")
    out_lps = sitk.DICOMOrient(st.session_state.synth_ct, "LPS")
    out_on_input = _resample_onto(in_lps, out_lps, sitk.sitkLinear)

    # GT belongs to the sample that produced the shown result, not to whatever
    # source/sample the widgets currently select.
    gt_on_input = None
    gt_path = st.session_state.gt_path
    if gt_path and os.path.exists(gt_path):
        gt_img = sitk.DICOMOrient(sitk.ReadImage(gt_path), "LPS")
        gt_on_input = _resample_onto(in_lps, gt_img, sitk.sitkLinear)

    # numpy arrays are indexed (z, y, x)
    in_vol = sitk.GetArrayFromImage(in_lps).astype(np.float32)
    syn_vol = sitk.GetArrayFromImage(out_on_input).astype(np.float32)
    gt_vol = sitk.GetArrayFromImage(gt_on_input).astype(np.float32) if gt_on_input is not None else None

    st.subheader("Input vs Synthetic CT Viewer (Axial only)")
    n_slices = in_vol.shape[0]
    idx = st.slider("Slice index (Axial/Z)", 0, n_slices - 1, n_slices // 2)

    img_in = _norm2u8(in_vol[idx])
    img_syn = _norm2u8(syn_vol[idx])
    img_gt = _norm2u8(gt_vol[idx]) if gt_vol is not None else None

    overlay_mask = st.checkbox("Overlay mask (red)")
    alpha = st.slider("Mask opacity", 0.0, 1.0, 0.35, 0.05, disabled=not overlay_mask)

    mask_plot = None
    if overlay_mask and st.session_state.input_mask is not None:
        mask_lps = sitk.DICOMOrient(st.session_state.input_mask, "LPS")
        mask_np = sitk.GetArrayFromImage(_resample_onto(in_lps, mask_lps, sitk.sitkNearestNeighbor))
        mask_slice = mask_np[min(idx, mask_np.shape[0] - 1)]
        # NaN outside the mask renders transparent in the overlay heatmap.
        mask_plot = np.where(mask_slice > 0, 1.0, np.nan)

    sx, sy, _ = in_lps.GetSpacing()
    xs = np.arange(img_in.shape[1]) * sx
    ys = np.arange(img_in.shape[0]) * sy

    panels = [img_in, img_syn]
    titles = ["Input (MRI/CBCT)", "Synthetic CT"]
    if img_gt is not None:
        panels.append(img_gt)
        titles.append("Ground-Truth CT")
    cols = len(panels)

    fig = make_subplots(rows=1, cols=cols, subplot_titles=tuple(titles))
    for c, panel in enumerate(panels, start=1):
        fig.add_trace(
            go.Heatmap(z=panel, x=xs, y=ys, colorscale="gray", zmin=0, zmax=255, showscale=False, hoverinfo="skip"),
            row=1,
            col=c,
        )

    if mask_plot is not None:
        red_scale = [[0.0, "rgba(255,0,0,1.0)"], [1.0, "rgba(255,0,0,1.0)"]]
        for c in range(1, cols + 1):
            fig.add_trace(
                go.Heatmap(z=mask_plot, x=xs, y=ys, colorscale=red_scale, showscale=False, opacity=alpha, hoverinfo="skip"),
                row=1,
                col=c,
            )

    for c in range(1, cols + 1):
        fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, row=1, col=c)
        # In LPS the row index runs anterior -> posterior; reversing y puts row 0
        # at the top (standard axial display), and scaleanchor keeps the
        # spacing-scaled coordinates at their physical aspect ratio.
        fig.update_yaxes(
            showticklabels=False,
            showgrid=False,
            zeroline=False,
            autorange="reversed",
            scaleanchor="x" if c == 1 else f"x{c}",
            row=1,
            col=c,
        )
    fig.update_layout(height=600, margin=dict(l=10, r=10, t=40, b=10))
    st.plotly_chart(fig, use_container_width=True)

    if cols == 3:
        st.caption(f"Axial (Z) slice {idx+1}/{n_slices} — All aligned to input geometry; GT only for samples.")
    else:
        st.caption(f"Axial (Z) slice {idx+1}/{n_slices} — Aligned to input geometry.")

    # Name/label downloads after the task that produced the result, which may
    # differ from the task currently selected in the radio.
    result_is_task1 = (st.session_state.run_task or task) == TASK1

    col_d1, col_d2, col_d3 = st.columns(3)
    with col_d3:
        _download_sitk_image(st.session_state.synth_ct, file_name="synth_ct.nii.gz", label="Download synthetic CT")
    with col_d1:
        if st.session_state.input_vol is not None:
            in_name = "input_mr.nii.gz" if result_is_task1 else "input_cbct.nii.gz"
            in_label = "Download input MRI" if result_is_task1 else "Download input CBCT"
            _download_sitk_image(st.session_state.input_vol, file_name=in_name, label=in_label)
        else:
            st.button("Download input", disabled=True)
    with col_d2:
        if st.session_state.input_mask is not None:
            _download_sitk_image(st.session_state.input_mask, file_name="input_mask.nii.gz", label="Download input Mask")
        else:
            st.button("Download input Mask", disabled=True)