| """
|
| Standalone Inference Script — Improved ICH Screening (2.5D, 5-fold Ensemble)
|
| =============================================================================
|
| Reads raw DICOM CT brain slices, reproduces improved preprocessing from the
|
| improvement notebooks, runs 5-fold EfficientNet-B4 ensemble inference, applies
|
| saved calibration, and generates:
|
| • Per-image JSON reports (fixed schema)
|
| • Slice-level CSV summary
|
| • Patient-level CSV summary
|
|
|
| No command-line arguments — all paths are configured in the CONFIG section.
|
|
|
| Requirements:
|
| pip install torch timm pydicom opencv-python-headless numpy pandas scikit-learn
|
|
|
| Usage:
|
| python run_inference.py
|
| """
|
|
|
| import datetime
|
| import json
|
| import pickle
|
| import warnings
|
| from pathlib import Path
|
| from typing import Dict, List, Optional, Tuple
|
|
|
| import cv2
|
| import numpy as np
|
| import pandas as pd
|
| import torch
|
| import torch.nn as nn
|
| import timm
|
|
|
|
|
# pydicom is imported defensively: when it is missing the module still loads
# and ``has_pydicom`` stays False so that main() can print a clear error.
try:
    import pydicom
    import pydicom.multival
except ImportError:
    has_pydicom = False
else:
    # Silence pydicom's "Invalid value for VR UI" UserWarning (raised from
    # pydicom.valuerep) so it does not flood the console during a run.
    warnings.filterwarnings(
        "ignore",
        message=r"Invalid value for VR UI:",
        category=UserWarning,
        module=r"pydicom\.valuerep",
    )
    has_pydicom = True
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# CONFIG: paths
# ---------------------------------------------------------------------------
# Model artefacts, the optional manifest and the output folder are resolved
# relative to the directory that contains this script.
SCRIPT_DIR = Path(__file__).resolve().parent

# One checkpoint per cross-validation fold (fold ids 0..4).
FOLD_MODEL_PATHS = [
    SCRIPT_DIR.joinpath(f"best_model_fold{fold}.pth") for fold in range(5)
]
CALIB_PARAMS_PATH = SCRIPT_DIR.joinpath("calibration_params.json")
ISOTONIC_MODELS_PATH = SCRIPT_DIR.joinpath("isotonic_models.pkl")
NORM_STATS_PATH = SCRIPT_DIR.joinpath("normalization_stats.json")

# Folder scanned (non-recursively) for the *.dcm slices to score.
DICOM_INPUT_DIR = Path(r"D:\major8thsem\stage_2_test")

# Optional CSV with "image_id" and "any" columns; used only to echo
# ground-truth labels into the outputs.
MANIFEST_PATH = SCRIPT_DIR.joinpath("manifest.csv")

# Root folder for the CSV summaries and the per-slice "reports" sub-folder.
OUTPUT_DIR = SCRIPT_DIR.joinpath("outputs")
|
|
|
|
|
# ---------------------------------------------------------------------------
# CONFIG: model architecture (must match the saved fold checkpoints, which
# are loaded with strict=True)
# ---------------------------------------------------------------------------
BACKBONE: str = "tf_efficientnet_b4"  # timm model name
IMG_SIZE: int = 380                   # slices are resized to IMG_SIZE x IMG_SIZE
IN_CHANNELS: int = 9                  # prev + center + next slice, 3 windows each
N_CLASSES: int = 6                    # one logit per entry of SUBTYPES
DROPOUT: float = 0.4                  # timm drop_rate and classifier-head dropout
DROP_PATH: float = 0.2                # timm drop_path_rate

# ---------------------------------------------------------------------------
# CONFIG: inference behaviour
# ---------------------------------------------------------------------------
# Patient-level pooling of slice probabilities:
# "max", "mean", "noisy_or" or "topk_mean".
PATIENT_AGG_METHOD: str = "topk_mean"
PATIENT_TOPK: int = 3  # k for "topk_mean"
# None -> use "threshold_at_spec90" from the calibration file (0.5 if absent).
DECISION_THRESHOLD: Optional[float] = None
# "ensemble" for all folds, or a single fold id (int or digit string).
FOLD_SELECTION = "ensemble"
GENERATE_HEATMAPS: bool = True  # also write Grad-CAM overlays per slice

# (window center, window width) pairs; each pair yields one image channel.
# NOTE(review): presumably brain / subdural / soft-tissue CT windows - confirm
# against the training notebooks.
WINDOWS: List[Tuple[int, int]] = [
    (40, 80),
    (75, 215),
    (40, 380),
]

# Output order of the model logits; index 0 ("any") drives the screening decision.
SUBTYPES: List[str] = [
    "any",
    "epidural",
    "intraparenchymal",
    "intraventricular",
    "subarachnoid",
    "subdural",
]
|
|
|
# ---------------------------------------------------------------------------
# Report wording (written verbatim into the JSON reports and CSV summaries)
# ---------------------------------------------------------------------------
OUTCOME_POSITIVE = "Hemorrhage indicator detected"
OUTCOME_NEGATIVE = "No hemorrhage indicator detected"

# Human-readable label for each confidence band key.
BAND_LABELS = {
    "HIGH": "High confidence",
    "MEDIUM": "Moderate confidence",
    "LOW": "Low confidence",
}

# Recommended action keyed by (screening outcome, confidence band).
# The separator is an em dash; it is spelled as the escape \u2014 so the text
# stays intact even if this file is re-saved with the wrong encoding (the
# previous revision had the dash corrupted into a stray Greek letter).
TRIAGE_ACTIONS = {
    ("POSITIVE", "HIGH"): "Urgent radiologist review recommended",
    ("POSITIVE", "MEDIUM"): "Prioritised radiologist review recommended",
    ("POSITIVE", "LOW"): "Radiologist review recommended \u2014 low confidence",
    ("NEGATIVE", "HIGH"): "Standard workflow \u2014 no urgent action",
    ("NEGATIVE", "MEDIUM"): "Standard workflow \u2014 manual review if clinically indicated",
    ("NEGATIVE", "LOW"): "Manual review recommended \u2014 model uncertainty high",
}

# Appended to every report.
DISCLAIMER = (
    "This report is produced by an AI-assisted screening tool and does NOT "
    "constitute a medical diagnosis. All screening findings must be reviewed "
    "and confirmed by a qualified, licensed medical professional before any "
    "clinical decision is made. The system is intended solely as a "
    "decision-support aid in a screening workflow and is not cleared for "
    "standalone diagnostic use."
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _to_scalar(val) -> float:
    """Return *val* as a float, using its first element when it is a sequence.

    DICOM attributes such as RescaleSlope may be a single value or a
    multi-valued container. Detection is duck-typed (anything indexable that
    is not text), so pydicom ``MultiValue`` objects, lists and tuples are all
    handled, and the helper works whether or not pydicom is importable.

    Raises:
        TypeError / ValueError: if neither the value nor its first element
            can be converted to float (including an empty sequence).
    """
    # Text is indexable too, but "1.5"[0] would silently become 1.0.
    if isinstance(val, (str, bytes)):
        return float(val)
    try:
        first = val[0]
    except (TypeError, IndexError):
        # Not subscriptable (plain numbers, numpy scalars): convert directly.
        return float(val)
    return float(first)
|
|
|
|
|
def apply_window(img_hu: np.ndarray, wc: float, ww: float) -> np.ndarray:
    """Map intensities inside a window to [0, 1], clipping everything outside.

    Args:
        img_hu: Array of intensities (after rescale slope/intercept).
        wc: Window center.
        ww: Window width; the window spans ``wc - ww/2`` to ``wc + ww/2``.

    Returns:
        Array of the same shape with values clipped to [0, 1].

    Raises:
        ValueError: if ``ww`` is zero. Previously this divided by zero and
            silently produced NaN/inf pixels that were fed to the model.
    """
    if ww == 0:
        raise ValueError("Window width (ww) must be non-zero")
    lo = wc - ww / 2
    hi = wc + ww / 2
    return np.clip((img_hu - lo) / (hi - lo), 0.0, 1.0)
|
|
|
|
|
def load_single_dicom_3ch(dcm_path: Path, size: int = IMG_SIZE) -> np.ndarray:
    """Read one DICOM slice and return it as a 3-channel windowed image.

    The pixel data is rescaled with RescaleSlope / RescaleIntercept
    (defaulting to 1 and 0 when absent), then each entry of ``WINDOWS``
    produces one channel that is windowed to [0, 1] and resized to
    ``size`` x ``size``.

    Returns:
        float32 array with the channels stacked on the last axis.
    """
    dataset = pydicom.dcmread(str(dcm_path))
    pixels = dataset.pixel_array.astype(np.float32)

    rescale_slope = _to_scalar(getattr(dataset, "RescaleSlope", 1))
    rescale_intercept = _to_scalar(getattr(dataset, "RescaleIntercept", 0))
    rescaled = pixels * rescale_slope + rescale_intercept

    # One windowed + resized plane per (center, width) pair.
    planes = [
        cv2.resize(
            apply_window(rescaled, center, width),
            (size, size),
            interpolation=cv2.INTER_AREA,
        )
        for center, width in WINDOWS
    ]
    return np.stack(planes, axis=-1).astype(np.float32)
|
|
|
|
|
def build_adjacency(dicom_dir: Path) -> pd.DataFrame:
    """Index the ``*.dcm`` files in *dicom_dir* and link neighbouring slices.

    Only headers are read (``stop_before_pixels=True``). Slices are ordered
    by (patient_id, series_uid, z_pos), where z_pos is the third component of
    ImagePositionPatient or, failing that, SliceLocation (0.0 if absent).
    Within each (patient, series) group every row receives the image id of
    the slice before and after it; the ends of a group get NaN.

    A file whose header cannot be read is kept with placeholder metadata
    ("UNKNOWN" / "UNKNOWN_SERIES" / 0.0) and a warning is printed, so the
    caller decides later whether the slice is usable.

    Returns:
        DataFrame with columns image_id, patient_id, series_uid, z_pos,
        dcm_path, prev_image_id, next_image_id (empty when no files match).
    """
    columns = [
        "image_id", "patient_id", "series_uid", "z_pos",
        "dcm_path", "prev_image_id", "next_image_id",
    ]
    records: List[dict] = []

    for dcm_path in sorted(dicom_dir.glob("*.dcm")):
        try:
            dcm = pydicom.dcmread(str(dcm_path), stop_before_pixels=True)
            patient_id = str(getattr(dcm, "PatientID", "UNKNOWN"))
            series_uid = str(getattr(dcm, "SeriesInstanceUID", "UNKNOWN_SERIES"))

            ipp = getattr(dcm, "ImagePositionPatient", None)
            if ipp is not None and len(ipp) >= 3:
                z_pos = float(ipp[2])
            else:
                z_pos = float(getattr(dcm, "SliceLocation", 0.0))
        except Exception as exc:
            # Best effort: keep the file, but say so instead of failing silently.
            print(f" [!] Could not read DICOM header of {dcm_path.name}: {exc}")
            patient_id = "UNKNOWN"
            series_uid = "UNKNOWN_SERIES"
            z_pos = 0.0

        records.append(
            {
                "image_id": dcm_path.stem,
                "patient_id": patient_id,
                "series_uid": series_uid,
                "z_pos": z_pos,
                "dcm_path": str(dcm_path),
            }
        )

    if not records:
        return pd.DataFrame(columns=columns)

    df = pd.DataFrame(records)
    df = df.sort_values(["patient_id", "series_uid", "z_pos"]).reset_index(drop=True)

    # Build the per-series grouping once and reuse it for both directions.
    ids_by_series = df.groupby(["patient_id", "series_uid"])["image_id"]
    df["prev_image_id"] = ids_by_series.shift(1)
    df["next_image_id"] = ids_by_series.shift(-1)
    return df
|
|
|
|
|
def build_9ch_for_row(row: pd.Series, image_path_map: Dict[str, Path], mean_9: np.ndarray, std_9: np.ndarray) -> np.ndarray:
    """Assemble the normalised 9-channel (prev + center + next) input for one slice.

    Each of the three slices contributes the 3 windowed channels produced by
    ``load_single_dicom_3ch``. A neighbour that is missing (NaN id) or not
    present in *image_path_map* is replaced by the center slice.

    Args:
        row: Adjacency row with "image_id" and, optionally, "prev_image_id" /
            "next_image_id".
        image_path_map: image id -> DICOM path.
        mean_9, std_9: Per-channel normalisation statistics (reshaped to
            broadcast over the last axis).

    Returns:
        Array of ``(x - mean) / (std + 1e-7)`` with channels on the last axis.
    """
    center_arr = load_single_dicom_3ch(image_path_map[row["image_id"]], size=IMG_SIZE)

    def _neighbour_or_center(column: str) -> np.ndarray:
        # Fall back to the center slice at series boundaries or unknown ids.
        neighbour_id = row.get(column)
        if pd.notna(neighbour_id) and neighbour_id in image_path_map:
            return load_single_dicom_3ch(image_path_map[neighbour_id], size=IMG_SIZE)
        return center_arr

    prev_arr = _neighbour_or_center("prev_image_id")
    next_arr = _neighbour_or_center("next_image_id")

    stacked = np.concatenate([prev_arr, center_arr, next_arr], axis=-1).astype(np.float32)
    channel_mean = mean_9.reshape(1, 1, -1)
    channel_std = std_9.reshape(1, 1, -1) + 1e-7
    return (stacked - channel_mean) / channel_std
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_model(
    backbone: str = BACKBONE,
    in_ch: int = IN_CHANNELS,
    n_cls: int = N_CLASSES,
    dropout: float = DROPOUT,
    drop_path: float = DROP_PATH,
) -> nn.Module:
    """Create the timm backbone with an ``in_ch``-channel stem and a fresh head.

    The model is created without pretrained weights and without timm's own
    classifier (``num_classes=0``). Its ``conv_stem`` is replaced by a plain
    ``nn.Conv2d`` accepting ``in_ch`` channels, and ``classifier`` becomes
    ``Dropout -> Linear(num_features, n_cls)``.

    The new stem is initialised by tiling the old stem weights along the
    input-channel axis and rescaling. The tiling is cropped to ``in_ch``, so
    channel counts that are not a multiple of the original stem width no
    longer crash; for multiples (3, 6, 9, ...) the result equals the former
    ``repeat(k) / k`` initialisation.

    NOTE(review): for timm "tf_" variants ``conv_stem`` may be a same-padding
    conv subclass; it is replaced here by a standard ``nn.Conv2d`` that reuses
    its ``padding`` attribute. Presumably this mirrors how the checkpoints
    were trained - confirm before changing.
    """
    model = timm.create_model(
        backbone,
        pretrained=False,
        num_classes=0,
        drop_rate=dropout,
        drop_path_rate=drop_path,
    )

    old_conv = model.conv_stem
    new_conv = nn.Conv2d(
        in_ch,
        old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=(old_conv.bias is not None),
    )

    src_ch = old_conv.weight.shape[1]
    tiles = -(-in_ch // src_ch)  # ceil division: enough copies to cover in_ch
    with torch.no_grad():
        tiled = old_conv.weight.repeat(1, tiles, 1, 1)[:, :in_ch]
        # Rescale so the summed response keeps roughly the original magnitude.
        new_conv.weight.copy_(tiled / (in_ch / src_ch))
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)
    model.conv_stem = new_conv

    n_feat = model.num_features
    model.classifier = nn.Sequential(nn.Dropout(p=dropout), nn.Linear(n_feat, n_cls))
    return model
|
|
|
|
|
def _find_gradcam_target_layer(model: nn.Module) -> nn.Module:
    """Pick the layer whose activations and gradients feed Grad-CAM.

    Prefers ``model.conv_head`` when it exists and is a module; otherwise
    returns the last ``nn.Conv2d`` in ``model.modules()`` order.

    Raises:
        RuntimeError: if the model contains no ``nn.Conv2d`` at all.
    """
    head = getattr(model, "conv_head", None)
    if isinstance(head, nn.Module):
        return head

    last_conv = None
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            last_conv = module

    if last_conv is None:
        raise RuntimeError("No convolutional layer found for Grad-CAM target")
    return last_conv
|
|
|
|
|
class GradCAM:
    """Grad-CAM helper bound to one model.

    Forward and full-backward hooks on the target layer (chosen by
    ``_find_gradcam_target_layer``) capture its output activations and the
    gradient flowing into that output. Call ``remove()`` when finished so
    the hooks are detached from the model.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self.activations = None  # target-layer output of the latest forward pass
        self.gradients = None    # gradient w.r.t. that output from the latest backward
        target = _find_gradcam_target_layer(model)
        self._fh = target.register_forward_hook(self._forward_hook)
        self._bh = target.register_full_backward_hook(self._backward_hook)

    def _forward_hook(self, _module, _inputs, output):
        self.activations = output

    def _backward_hook(self, _module, _grad_input, grad_output):
        self.gradients = grad_output[0]

    def remove(self):
        """Detach both hooks from the target layer."""
        self._fh.remove()
        self._bh.remove()

    def generate(self, input_tensor: torch.Tensor, class_idx: int = 0) -> Tuple[np.ndarray, np.ndarray]:
        """Run one forward/backward pass and return ``(logits, cam)``.

        Args:
            input_tensor: Model input; the logits are squeezed on dim 0, so a
                batch of one is expected.
            class_idx: Output column whose score is back-propagated.

        Returns:
            logits: float32 array of the model output for the input.
            cam: float32 map at the target layer's spatial resolution,
                min-max scaled to [0, 1]. It is all zeros - with the same
                shape - when the map has no positive response, so maps from
                several models can always be stacked. If the hooks captured
                nothing, a zero map with the input's spatial shape is
                returned instead.
        """
        # Forget captures from any earlier call: if a hook fails to fire, the
        # previous image's tensors must never be mistaken for this one's.
        self.activations = None
        self.gradients = None

        self.model.zero_grad(set_to_none=True)
        with torch.enable_grad():
            output = self.model(input_tensor)
            output[:, class_idx].sum().backward()

        logits = output.squeeze(0).detach().cpu().numpy().astype(np.float32)
        input_hw = tuple(int(s) for s in input_tensor.shape[-2:])

        if self.activations is None or self.gradients is None:
            return logits, np.zeros(input_hw, dtype=np.float32)

        acts = self.activations.detach()
        grads = self.gradients.detach()
        # Release the captured tensors (and the autograd graph they keep
        # alive) as soon as detached copies exist.
        self.activations = None
        self.gradients = None

        # Channel weights = spatially averaged gradients (standard Grad-CAM).
        weights = grads.mean(dim=(2, 3), keepdim=True)
        cam = torch.relu((weights * acts).sum(dim=1)).squeeze(0).cpu().numpy().astype(np.float32)

        if cam.size == 0:
            return logits, np.zeros(input_hw, dtype=np.float32)
        if float(cam.max()) <= 0.0:
            return logits, np.zeros_like(cam)

        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return logits, cam
|
|
|
|
|
def make_overlay(orig_rgb_u8: np.ndarray, cam: np.ndarray, alpha: float = 0.45) -> np.ndarray:
    """Blend a JET-coloured heatmap of *cam* over an RGB image.

    Args:
        orig_rgb_u8: Base image; its first two axes give the target
            height and width.
        cam: 2-D activation map, clipped to [0, 1] after resizing.
        alpha: Weight of the heatmap; the image gets ``1 - alpha``.

    Returns:
        uint8 RGB overlay with the same height and width as the base image.
    """
    height, width = orig_rgb_u8.shape[0], orig_rgb_u8.shape[1]
    cam_resized = cv2.resize(cam, (width, height), interpolation=cv2.INTER_LINEAR)

    heat_u8 = (np.clip(cam_resized, 0.0, 1.0) * 255.0).astype(np.uint8)
    # applyColorMap yields BGR; convert so both blend operands are RGB.
    heat_rgb = cv2.cvtColor(cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)

    blended = alpha * heat_rgb + (1 - alpha) * orig_rgb_u8
    return blended.astype(np.uint8)
|
|
|
|
|
def load_models(device: str, fold_selection=None) -> Tuple[List[nn.Module], List[int]]:
    """Build and load the requested fold model(s) in eval mode on *device*.

    Args:
        device: Torch device string used for ``map_location`` and ``.to()``.
        fold_selection: ``"ensemble"`` (all folds), an int fold id, or a
            digit string; ``None`` falls back to ``FOLD_SELECTION``.

    Returns:
        ``(models, loaded_folds)``: the loaded models and their fold ids, in
        matching order. Out-of-range ids and missing checkpoint files are
        reported and skipped, so both lists may be empty.

    Raises:
        ValueError: if *fold_selection* is none of the accepted forms.
    """
    if fold_selection is None:
        fold_selection = FOLD_SELECTION

    if isinstance(fold_selection, str) and fold_selection.lower() == "ensemble":
        fold_indices = list(range(len(FOLD_MODEL_PATHS)))
    elif isinstance(fold_selection, int):
        fold_indices = [fold_selection]
    elif isinstance(fold_selection, str) and fold_selection.isdigit():
        fold_indices = [int(fold_selection)]
    else:
        raise ValueError('FOLD_SELECTION must be "ensemble" or an integer fold id (0..4).')

    models: List[nn.Module] = []
    loaded_folds: List[int] = []

    for fold_idx in fold_indices:
        if not 0 <= fold_idx < len(FOLD_MODEL_PATHS):
            print(f" [!] Invalid fold index: {fold_idx} (skipping)")
            continue
        path = FOLD_MODEL_PATHS[fold_idx]
        if not path.exists():
            print(f" [!] Missing fold checkpoint: {path.name} (skipping)")
            continue

        model = build_model()
        # SECURITY: torch.load unpickles the file. Only load checkpoints from
        # a trusted source (consider weights_only=True where the installed
        # torch version supports it).
        state = torch.load(str(path), map_location=device)

        # Accept checkpoints saved as {"state_dict": ...} or
        # {"model_state_dict": ...} as well as bare state dicts.
        if isinstance(state, dict):
            for wrapper_key in ("model_state_dict", "state_dict"):
                if isinstance(state.get(wrapper_key), dict):
                    state = state[wrapper_key]
                    break

        model.load_state_dict(state, strict=True)
        model = model.to(device)
        model.eval()
        models.append(model)
        loaded_folds.append(fold_idx)

    return models, loaded_folds
|
|
|
|
|
def sigmoid_np(x: np.ndarray) -> np.ndarray:
    """Element-wise logistic sigmoid ``1 / (1 + exp(-x))``.

    Accepts anything ``np.asarray`` can convert (arrays, lists, scalars).
    For very negative inputs ``exp(-x)`` overflows to ``inf``, which still
    gives the correct limit of 0.0; the overflow warning NumPy would emit
    for that case is suppressed because the result is well defined.
    """
    x = np.asarray(x)
    with np.errstate(over="ignore"):
        return 1.0 / (1.0 + np.exp(-x))
|
|
|
|
|
def apply_calibration(raw_logits: np.ndarray, calib_cfg: dict, iso_models) -> np.ndarray:
    """Turn raw logits into calibrated probabilities.

    When ``calib_cfg["best_method"]`` is ``"isotonic"`` and *iso_models* is
    available, each subtype's sigmoid probability is passed through its own
    model (looked up by subtype name, then by index, for dicts; by position
    for lists/tuples) and clipped to [0, 1]. A subtype without a model keeps
    its raw probability. Otherwise temperature scaling is applied:
    ``sigmoid(logits / temperature)`` with the temperature floored at 1e-6.

    A non-finite isotonic prediction falls back to the raw probability.
    scikit-learn's ``IsotonicRegression`` returns NaN outside its fitted
    range unless it was built with ``out_of_bounds="clip"``, and a NaN would
    otherwise fail every threshold comparison and read as a negative result.

    NOTE(review): if "isotonic" is configured but no models were loaded, this
    silently uses temperature scaling while reports still name "isotonic".

    Args:
        raw_logits: 1-D logits ordered like ``SUBTYPES``.
        calib_cfg: Calibration settings ("best_method", "temperature").
        iso_models: Per-subtype calibrators exposing ``predict``, or None.

    Returns:
        float32 array of calibrated probabilities.
    """
    best_method = calib_cfg.get("best_method", "temperature")
    temperature = float(calib_cfg.get("temperature", 1.0))

    if best_method == "isotonic" and iso_models is not None:
        raw_probs = sigmoid_np(raw_logits)
        cal_probs = np.zeros_like(raw_probs, dtype=np.float32)

        for i, subtype in enumerate(SUBTYPES):
            model_i = None
            if isinstance(iso_models, dict):
                model_i = iso_models.get(subtype)
                if model_i is None:
                    model_i = iso_models.get(i)
            elif isinstance(iso_models, (list, tuple)) and i < len(iso_models):
                model_i = iso_models[i]

            calibrated = float(raw_probs[i])
            if model_i is not None:
                predicted = float(model_i.predict([raw_probs[i]])[0])
                if np.isfinite(predicted):
                    calibrated = float(np.clip(predicted, 0.0, 1.0))
            cal_probs[i] = calibrated
        return cal_probs

    return sigmoid_np(raw_logits / max(temperature, 1e-6)).astype(np.float32)
|
|
|
|
|
def patient_aggregate(values: np.ndarray, method: str, topk: int) -> float:
    """Pool slice-level probabilities into a single patient-level score.

    Supported methods: "max", "mean", "noisy_or" (``1 - prod(1 - p)`` with
    p clipped to [0, 1]) and "topk_mean" (mean of the *topk* largest values,
    with k clamped to ``1..len(values)``).

    Returns:
        The pooled score, or 0.0 when *values* is empty.

    Raises:
        ValueError: for an unknown method (only checked for non-empty input).
    """
    if len(values) == 0:
        return 0.0

    if method == "topk_mean":
        k = min(max(int(topk), 1), len(values))
        largest = np.sort(values)[-k:]
        return float(np.mean(largest))

    reducers = {
        "max": np.max,
        "mean": np.mean,
        "noisy_or": lambda v: 1.0 - np.prod(1.0 - np.clip(v, 0.0, 1.0)),
    }
    reducer = reducers.get(method)
    if reducer is None:
        raise ValueError(f"Unknown PATIENT_AGG_METHOD: {method}")
    return float(reducer(values))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_slice_report(
    image_id: str,
    patient_id: str,
    probs: Dict[str, float],
    calib_cfg: dict,
    threshold: float,
    loaded_folds: List[int],
    report_image_path: Optional[str] = None,
    heatmap_path: Optional[str] = None,
    true_label: Optional[int] = None,
) -> dict:
    """Assemble the per-slice JSON report.

    Args:
        image_id: Slice identifier; its last 8 characters go into the report id.
        patient_id: Patient identifier copied into the report.
        probs: Calibrated probability per subtype; must contain "any".
        calib_cfg: Calibration settings; supplies "triage_high_thresh"
            (default 0.7), "triage_low_thresh" (default 0.3) and "best_method".
        threshold: Decision threshold applied to ``probs["any"]``.
        loaded_folds: Fold ids that produced the prediction.
        report_image_path: Preview image path, if one was written.
        heatmap_path: Grad-CAM overlay path, if one was written.
        true_label: Ground-truth "any" label; reported as "N/A" when None.

    Returns:
        The report dict. An "explainability" section is included only when
        at least one of the two image paths is non-empty.

    The report id and ``generated_at`` come from a single UTC timestamp, so
    they always agree with each other (previously the id used naive local
    time from a separate clock read).
    """
    cal_any = probs["any"]
    high_thr = float(calib_cfg.get("triage_high_thresh", 0.7))
    low_thr = float(calib_cfg.get("triage_low_thresh", 0.3))

    # NOTE(review): the band is derived from P(any) itself, not from how far
    # the score is from the decision threshold. A very low probability thus
    # becomes NEGATIVE + "LOW" and receives the "model uncertainty high"
    # action. Behaviour is kept as-is; confirm this is the intended triage
    # policy before relying on the band for negatives.
    if cal_any >= high_thr:
        band = "HIGH"
    elif cal_any >= low_thr:
        band = "MEDIUM"
    else:
        band = "LOW"

    is_positive = cal_any >= threshold
    outcome_key = "POSITIVE" if is_positive else "NEGATIVE"

    generated_at = datetime.datetime.now(datetime.timezone.utc)

    report = {
        "report_id": f"RPT_{generated_at.strftime('%Y%m%d_%H%M%S')}_{image_id[-8:]}",
        "generated_at": generated_at.isoformat(),
        "image_id": image_id,
        "patient_id": patient_id,
        "ground_truth_any": int(true_label) if true_label is not None else "N/A",
        "screening_module": {
            "version": "2.0",
            "architecture": BACKBONE,
            "input_type": "2.5D (9ch: prev+center+next)",
            "ensemble": "ensemble" if len(loaded_folds) > 1 else "single-fold",
            "folds_used": loaded_folds,
            "calibration_method": calib_cfg.get("best_method", "temperature"),
        },
        "prediction": {
            "screening_outcome": OUTCOME_POSITIVE if is_positive else OUTCOME_NEGATIVE,
            "decision_threshold_any": round(float(threshold), 6),
            "confidence_band": band,
            "confidence_band_label": BAND_LABELS[band],
            **{f"calibrated_prob_{k}": round(float(v), 6) for k, v in probs.items()},
        },
        "triage": {
            "action": TRIAGE_ACTIONS[(outcome_key, band)],
            "urgency": "URGENT" if (is_positive and band == "HIGH") else "STANDARD",
        },
        "disclaimer": DISCLAIMER,
    }

    if report_image_path or heatmap_path:
        report["explainability"] = {
            "method": "Gradient-weighted Class Activation Mapping (Grad-CAM)",
            "image_path": report_image_path,
            "heatmap_path": heatmap_path,
            "note": (
                "Highlighted regions indicate areas with greatest influence on the "
                "screening decision. These are not confirmed anatomical findings."
            ),
        }

    return report
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_manifest_labels() -> Dict[str, int]:
    """Return image_id -> ground-truth "any" label from the optional manifest.

    Rows without a label are dropped so that a blank cell cannot crash the
    later integer conversion. Any problem reading the file is reported and
    results in an empty mapping; the manifest is never required.
    """
    if not MANIFEST_PATH.exists():
        return {}
    try:
        manifest = pd.read_csv(MANIFEST_PATH)
        if "image_id" not in manifest.columns or "any" not in manifest.columns:
            return {}
        labelled = manifest.dropna(subset=["any"])
        labels = {
            image_id: int(flag)
            for image_id, flag in zip(labelled["image_id"], labelled["any"])
        }
    except Exception as exc:
        print(f" [!] Manifest load skipped: {exc}")
        return {}
    print(f" Manifest labels : loaded {len(labels)} rows")
    return labels


def _predict_slice(tensor, models, gradcam_objects):
    """Run every fold model on one input; return ``(fold_logits, fold_cams)``.

    With heatmaps enabled each Grad-CAM wrapper performs the forward pass
    itself (gradients are needed); otherwise the models run under no_grad
    and ``fold_cams`` stays empty.
    """
    fold_logits = []
    fold_cams = []
    if GENERATE_HEATMAPS:
        for cam_obj in gradcam_objects:
            logits, cam = cam_obj.generate(tensor, class_idx=0)
            fold_logits.append(logits)
            fold_cams.append(cam)
    else:
        with torch.no_grad():
            for model in models:
                logits = model(tensor).squeeze(0).detach().cpu().numpy().astype(np.float32)
                fold_logits.append(logits)
    return fold_logits, fold_cams


def _average_cams(fold_cams) -> np.ndarray:
    """Average the per-fold CAMs, tolerating maps of different shapes."""
    if not fold_cams:
        return np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.float32)
    if len({cam.shape for cam in fold_cams}) > 1:
        # Mixed resolutions cannot be stacked; bring all maps to image size.
        fold_cams = [
            cv2.resize(cam, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
            for cam in fold_cams
        ]
    return np.mean(np.stack(fold_cams, axis=0), axis=0)


def _write_slice_images(row, image_id, reports_dir, fold_cams) -> Tuple[str, str]:
    """Write the preview PNG (and Grad-CAM overlay) for one slice.

    Returns:
        ``(preview_path, heatmap_path)`` as strings. Both are "" when
        anything fails (the failure is printed; it does not abort the run),
        and ``heatmap_path`` is "" when heatmaps are disabled.
    """
    preview_path = reports_dir / f"{image_id}_preview.png"
    heatmap_path = reports_dir / f"{image_id}_gradcam.png"
    try:
        center_rgb = load_single_dicom_3ch(Path(row["dcm_path"]), size=IMG_SIZE)
        center_rgb_u8 = (np.clip(center_rgb, 0.0, 1.0) * 255.0).astype(np.uint8)
        # cv2.imwrite signals failure by returning False, not by raising.
        if not cv2.imwrite(str(preview_path), cv2.cvtColor(center_rgb_u8, cv2.COLOR_RGB2BGR)):
            raise OSError(f"could not write {preview_path}")
        if GENERATE_HEATMAPS:
            overlay_rgb = make_overlay(center_rgb_u8, _average_cams(fold_cams), alpha=0.45)
            if not cv2.imwrite(str(heatmap_path), cv2.cvtColor(overlay_rgb, cv2.COLOR_RGB2BGR)):
                raise OSError(f"could not write {heatmap_path}")
    except Exception as exc:
        print(f" [!] Report images skipped for {image_id}: {exc}")
        return "", ""
    return str(preview_path), (str(heatmap_path) if GENERATE_HEATMAPS else "")


def main():
    """Run the full screening pipeline using the CONFIG constants.

    Steps: validate inputs, load normalisation / calibration artefacts and
    the fold models, index the DICOM folder, score every slice (optionally
    with Grad-CAM), write one JSON report plus images per slice, then write
    the slice, report-summary and patient CSVs and print a summary.

    Problems that make the run impossible are printed and end the function
    early; a slice whose input cannot be built is reported and skipped.
    """
    rule = "-" * 72

    print("=" * 72)
    print(" ICH SCREENING \u2014 Improved 2.5D Inference")
    print("=" * 72)

    if not has_pydicom:
        print("ERROR: pydicom is not installed. Run: pip install pydicom")
        return

    if not DICOM_INPUT_DIR.exists():
        print(f"ERROR: DICOM input folder not found: {DICOM_INPUT_DIR}")
        print(" Create this folder and place .dcm files inside it.")
        return

    for path in (CALIB_PARAMS_PATH, NORM_STATS_PATH):
        if not path.exists():
            print(f"ERROR: Required file missing: {path}")
            return

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\n Device : {device}")

    with open(NORM_STATS_PATH, "r", encoding="utf-8") as f:
        norm = json.load(f)
    mean_9 = np.asarray(norm["mean_9ch"], dtype=np.float32)
    std_9 = np.asarray(norm["std_9ch"], dtype=np.float32)

    with open(CALIB_PARAMS_PATH, "r", encoding="utf-8") as f:
        calib_cfg = json.load(f)

    iso_models = None
    if ISOTONIC_MODELS_PATH.exists():
        # SECURITY: pickle executes code on load. Only use an
        # isotonic_models.pkl that comes from a trusted source.
        with open(ISOTONIC_MODELS_PATH, "rb") as f:
            iso_models = pickle.load(f)

    threshold = (
        float(DECISION_THRESHOLD)
        if DECISION_THRESHOLD is not None
        else float(calib_cfg.get("threshold_at_spec90", 0.5))
    )

    print(f" Backbone : {BACKBONE}")
    print(f" Input : {IN_CHANNELS}ch @ {IMG_SIZE}x{IMG_SIZE}")
    print(f" Calibration : {calib_cfg.get('best_method', 'temperature')}")
    print(f" Decision threshold: {threshold:.6f}")

    models, loaded_folds = load_models(device, fold_selection=FOLD_SELECTION)
    if not models:
        print("ERROR: No fold checkpoints could be loaded.")
        return
    print(f" Fold models loaded: {len(models)} (folds: {loaded_folds})")

    adjacency_df = build_adjacency(DICOM_INPUT_DIR)
    if adjacency_df.empty:
        print(f"ERROR: No .dcm files found in {DICOM_INPUT_DIR}")
        return

    image_path_map = {
        Path(p).stem: Path(p)
        for p in adjacency_df["dcm_path"].tolist()
    }

    label_map = _load_manifest_labels()

    reports_dir = OUTPUT_DIR / "reports"
    reports_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{rule}")
    print(f" Processing {len(adjacency_df)} DICOM slices")
    print(f"{rule}\n")

    # Hooks are attached only once all early exits above are passed, and the
    # try/finally guarantees they are removed however the run ends.
    gradcam_objects = [GradCAM(m) for m in models] if GENERATE_HEATMAPS else []
    try:
        slice_rows = []
        report_summary_rows = []
        patient_probs: Dict[str, List[float]] = {}
        total = len(adjacency_df)

        for position, (_, row) in enumerate(adjacency_df.iterrows(), start=1):
            image_id = row["image_id"]
            patient_id = row["patient_id"]

            try:
                img_9ch = build_9ch_for_row(row, image_path_map, mean_9=mean_9, std_9=std_9)
            except Exception as exc:
                print(f" [{position}/{total}] SKIP {image_id}: {exc}")
                continue

            # HWC -> NCHW with a batch dimension of one.
            tensor = torch.from_numpy(img_9ch).permute(2, 0, 1).unsqueeze(0).to(device)

            fold_logits, fold_cams = _predict_slice(tensor, models, gradcam_objects)

            # Ensemble by averaging logits, then calibrate the averaged logits.
            mean_logits = np.mean(np.stack(fold_logits, axis=0), axis=0)
            raw_probs = sigmoid_np(mean_logits)
            cal_probs = apply_calibration(mean_logits, calib_cfg, iso_models)
            probs_dict = {name: float(cal_probs[j]) for j, name in enumerate(SUBTYPES)}

            report_image_path, report_heatmap_path = _write_slice_images(
                row, image_id, reports_dir, fold_cams
            )

            true_any = label_map.get(image_id)
            rep = build_slice_report(
                image_id=image_id,
                patient_id=patient_id,
                probs=probs_dict,
                calib_cfg=calib_cfg,
                threshold=threshold,
                loaded_folds=loaded_folds,
                report_image_path=report_image_path,
                heatmap_path=report_heatmap_path,
                true_label=int(true_any) if true_any is not None else None,
            )

            report_path = reports_dir / f"{image_id}_report.json"
            with open(report_path, "w", encoding="utf-8") as f:
                json.dump(rep, f, indent=2)

            slice_rows.append(
                {
                    "image_id": image_id,
                    "patient_id": patient_id,
                    "true_any": int(true_any) if true_any is not None else "",
                    "pred_any": int(probs_dict["any"] >= threshold),
                    "cal_any": round(probs_dict["any"], 6),
                    "raw_any": round(float(raw_probs[0]), 6),
                    **{f"cal_{name}": round(float(probs_dict[name]), 6) for name in SUBTYPES[1:]},
                    "confidence_band": rep["prediction"]["confidence_band"],
                    "triage_action": rep["triage"]["action"],
                    "urgency": rep["triage"]["urgency"],
                }
            )

            report_summary_rows.append(
                {
                    "image_id": image_id,
                    "true_label": int(true_any) if true_any is not None else "",
                    "screening_outcome": rep["prediction"]["screening_outcome"],
                    "raw_prob": round(float(raw_probs[0]), 6),
                    "cal_prob": round(float(probs_dict["any"]), 6),
                    "confidence_band": rep["prediction"]["confidence_band"],
                    "triage_action": rep["triage"]["action"],
                    "urgency": rep["triage"]["urgency"],
                    "image_path": report_image_path,
                    "heatmap_path": report_heatmap_path,
                }
            )

            patient_probs.setdefault(patient_id, []).append(probs_dict["any"])

            status = "[+] POS" if probs_dict["any"] >= threshold else "[-] NEG"
            print(
                f" [{position}/{total}] {image_id} -> {status} "
                f"cal_any={probs_dict['any']:.4f}"
            )

        if not slice_rows:
            print("\nERROR: No slices were processed successfully.")
            return

        slice_df = pd.DataFrame(slice_rows)
        slice_csv_path = OUTPUT_DIR / "slice_predictions.csv"
        slice_df.to_csv(slice_csv_path, index=False)

        report_summary_df = pd.DataFrame(report_summary_rows)
        report_summary_csv_path = OUTPUT_DIR / "report_summary.csv"
        report_summary_df.to_csv(report_summary_csv_path, index=False)

        patient_rows = []
        for pid, vals in patient_probs.items():
            arr = np.asarray(vals, dtype=np.float32)
            agg_prob = patient_aggregate(arr, PATIENT_AGG_METHOD, PATIENT_TOPK)
            patient_rows.append(
                {
                    "patient_id": pid,
                    "n_slices": int(len(arr)),
                    "agg_method": PATIENT_AGG_METHOD,
                    "agg_any_probability": round(float(agg_prob), 6),
                    "pred_any": int(agg_prob >= threshold),
                }
            )

        patient_df = pd.DataFrame(patient_rows)
        patient_csv_path = OUTPUT_DIR / "patient_predictions.csv"
        patient_df.to_csv(patient_csv_path, index=False)
    finally:
        for cam_obj in gradcam_objects:
            cam_obj.remove()

    n_pos = int((slice_df["pred_any"] == 1).sum())
    n_total = len(slice_df)
    n_urgent = int((slice_df["urgency"] == "URGENT").sum())

    print(f"\n{rule}")
    print(" INFERENCE COMPLETE")
    print(f"{rule}")
    print(f" Slices processed : {n_total}")
    print(f" Positive slices : {n_pos}")
    print(f" Urgent escalations : {n_urgent}")
    print(f" Patients processed : {len(patient_df)}")
    print("\n Outputs:")
    print(f" JSON reports : {reports_dir}")
    print(f" Report images : {reports_dir}")
    print(f" Report summary : {report_summary_csv_path}")
    print(f" Slice CSV : {slice_csv_path}")
    print(f" Patient CSV : {patient_csv_path}")
    print(f"{rule}")


if __name__ == "__main__":
    main()
|
|
|