Harshit Ghosh committed on
Commit ea664f8 · 1 Parent(s): 410e48e

making changes for huggingface
.gitignore CHANGED
@@ -45,7 +45,7 @@ download_imp/*.pt
 download_imp/*.pkl
 download_imp/*.onnx
 
-download_imp/*
+# download_imp/*
 # Local downloaded artifacts
 download/
 # Local downloaded data
download_imp/INFERENCE_USAGE.md ADDED
@@ -0,0 +1,107 @@
# Improved Model Inference — Usage Guide (`download_imp`)

## What this runner does

`run_inference.py` runs the **latest improved model** trained from the notebooks in `improvement/`:

- EfficientNet-B4 backbone (`tf_efficientnet_b4`)
- 2.5D input (prev + center + next slices → 9 channels)
- 6 outputs (`any` + 5 hemorrhage subtypes)
- 5-fold ensemble (`best_model_fold0..4.pth`)
- Saved calibration (`isotonic`/`temperature`) from `calibration_params.json`

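The 2.5D stacking listed above can be sketched in a few lines. This is a minimal illustration with toy arrays, not code from the runner; `stack_25d` is a hypothetical helper mirroring the neighbor-fallback behavior described later in this guide:

```python
import numpy as np

def stack_25d(prev, center, nxt):
    """Stack three 3-channel windowed slices into one 9-channel 2.5D input.

    A missing neighbor (None) falls back to the center slice, as at the
    edges of a series.
    """
    prev = center if prev is None else prev
    nxt = center if nxt is None else nxt
    return np.concatenate([prev, center, nxt], axis=-1)  # (H, W, 9)

# Toy example: a 4x4 "slice" with 3 window channels and no neighbors
center = np.random.rand(4, 4, 3).astype(np.float32)
x = stack_25d(None, center, None)
print(x.shape)  # (4, 4, 9)
```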
Outputs:

- Per-slice JSON reports (`outputs/reports/*.json`)
- Slice-level CSV (`outputs/slice_predictions.csv`)
- Patient-level CSV (`outputs/patient_predictions.csv`)

---

## Required files (already in `download_imp/`)

- `best_model_fold0.pth`
- `best_model_fold1.pth`
- `best_model_fold2.pth`
- `best_model_fold3.pth`
- `best_model_fold4.pth`
- `calibration_params.json`
- `isotonic_models.pkl`
- `normalization_stats.json`
- `manifest.csv` (optional at inference time; used only for `true_any` if IDs match)

---

## Python package requirements

```bash
pip install -r requirements.txt
```

Notes:

- `timm` is required to construct the `tf_efficientnet_b4` model.
- `scikit-learn` is needed to deserialize and use `isotonic_models.pkl`.

---

## Folder setup

Create the `dicom_inputs/` folder and place your DICOM files there:

```text
download_imp/
├── run_inference.py
├── run_interface.py
├── best_model_fold0.pth
├── best_model_fold1.pth
├── best_model_fold2.pth
├── best_model_fold3.pth
├── best_model_fold4.pth
├── calibration_params.json
├── isotonic_models.pkl
├── normalization_stats.json
├── manifest.csv
└── dicom_inputs/
    ├── ID_xxx1.dcm
    ├── ID_xxx2.dcm
    └── ...
```

---

## Run commands

From the workspace root:

```bash
cd download_imp
python run_inference.py
```

or, equivalently:

```bash
python run_interface.py
```

---

## Important behavior

- No CLI arguments; all settings live at the top of `run_inference.py` (`CONFIG` section).
- `FOLD_SELECTION` controls checkpoint selection:
  - `"ensemble"` = use all available folds and average their logits
  - `0..4` = use one specific fold only
- If `best_method` is `isotonic`, the runner uses `isotonic_models.pkl`.
- A missing prev/next slice in a series is handled exactly like the training cache logic: the neighbor falls back to the center slice.
- The decision threshold defaults to `threshold_at_spec90` from `calibration_params.json` unless overridden in the config.

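For the temperature path of the calibration, the thresholding logic amounts to only a few lines. This is a sketch, not the runner's code: `decide` is a hypothetical helper, the values are illustrative, and when `best_method` is `isotonic` the sigmoid/temperature step is replaced by the fitted isotonic models:

```python
import numpy as np

def decide(logit_any, calib):
    """Temperature-scale the ensemble 'any' logit, then apply the saved
    operating threshold (threshold_at_spec90 by default)."""
    t = max(float(calib.get("temperature", 1.0)), 1e-6)
    prob = 1.0 / (1.0 + np.exp(-logit_any / t))
    thr = float(calib.get("threshold_at_spec90", 0.5))
    return prob, prob >= thr

# Illustrative values, rounded from calibration_params.json
calib = {"temperature": 0.768, "threshold_at_spec90": 0.162}
prob, is_positive = decide(1.2, calib)   # clearly above threshold
</antml_code_removed>```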
---

## Recommended production checklist

1. Keep all fold checkpoints and calibration files in the same `download_imp/` directory.
2. Verify that the DICOMs are non-contrast head CT slices before inference.
3. Run once on a small sample and review `slice_predictions.csv` and the JSON reports.
4. Have a radiologist review all flagged and uncertain cases.
download_imp/calibration_params.json ADDED
@@ -0,0 +1,17 @@
{
  "best_method": "isotonic",
  "temperature": 0.7681182881076687,
  "ece_raw": 0.029524699832706974,
  "ece_temp": 0.008160804909196835,
  "ece_isotonic": 6.68228018673829e-18,
  "brier_raw": 0.05208174680122169,
  "brier_temp": 0.0514404718775377,
  "brier_isotonic": 0.050937655679539166,
  "triage_high_thresh": 0.7,
  "triage_low_thresh": 0.3,
  "oof_auc_any": 0.9501744154042578,
  "sensitivity_at_spec90": 0.8556471787269526,
  "specificity_at_sens95": 0.7441809977204707,
  "threshold_at_spec90": 0.16216216216216217,
  "threshold_at_sens95": 0.04074505238649592
}
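The `triage_high_thresh` / `triage_low_thresh` values above drive the confidence band reported for each slice. A minimal sketch of that banding, mirroring `build_slice_report` in `run_inference.py` (`confidence_band` is an illustrative helper name, not part of the script):

```python
def confidence_band(cal_any, high=0.7, low=0.3):
    """Map a calibrated 'any' probability to a triage confidence band,
    using triage_high_thresh / triage_low_thresh from calibration_params.json."""
    if cal_any >= high:
        return "HIGH"
    if cal_any >= low:
        return "MEDIUM"
    return "LOW"
```

Note that the band is independent of the positive/negative decision: a slice just above the 0.162 operating threshold is still reported as a LOW-confidence positive.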
download_imp/manifest.csv ADDED
The diff for this file is too large to render. See raw diff
 
download_imp/normalization_stats.json ADDED
@@ -0,0 +1,38 @@
{
  "mean_3ch": [0.162136, 0.141483, 0.183675],
  "std_3ch": [0.312067, 0.283885, 0.305968],
  "mean_9ch": [
    0.162136, 0.141483, 0.183675,
    0.162136, 0.141483, 0.183675,
    0.162136, 0.141483, 0.183675
  ],
  "std_9ch": [
    0.312067, 0.283885, 0.305968,
    0.312067, 0.283885, 0.305968,
    0.312067, 0.283885, 0.305968
  ],
  "n_sample_images": 5000,
  "n_pixels": 722000000,
  "img_size": 380,
  "note": "Stats computed on windowed [0,1] images. Pre-normalization applied in cache."
}
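These stats are applied as a standard per-channel normalization over the 9-channel stack. A minimal sketch under the assumption of a windowed `[0, 1]` input image, matching the final step of `build_9ch_for_row` in `run_inference.py`:

```python
import numpy as np

# Per-channel stats from normalization_stats.json (3-channel stats tiled x3)
mean_9 = np.asarray([0.162136, 0.141483, 0.183675] * 3, dtype=np.float32)
std_9 = np.asarray([0.312067, 0.283885, 0.305968] * 3, dtype=np.float32)

# A windowed [0, 1] image at the training resolution (img_size = 380)
img = np.random.rand(380, 380, 9).astype(np.float32)

# Broadcast mean/std over H and W; epsilon guards against division by zero
normed = (img - mean_9.reshape(1, 1, -1)) / (std_9.reshape(1, 1, -1) + 1e-7)
```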
download_imp/run_inference.py ADDED
@@ -0,0 +1,746 @@
"""
Standalone Inference Script — Improved ICH Screening (2.5D, 5-fold Ensemble)
=============================================================================
Reads raw DICOM CT brain slices, reproduces improved preprocessing from the
improvement notebooks, runs 5-fold EfficientNet-B4 ensemble inference, applies
saved calibration, and generates:
  • Per-image JSON reports (fixed schema)
  • Slice-level CSV summary
  • Patient-level CSV summary

No command-line arguments — all paths are configured in the CONFIG section.

Requirements:
    pip install torch timm pydicom opencv-python-headless numpy pandas scikit-learn

Usage:
    python run_inference.py
"""

import datetime
import json
import pickle
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from zoneinfo import ZoneInfo

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import timm

# Try importing pydicom — needed for DICOM input mode
has_pydicom = False
pydicom = None
try:
    import pydicom
    import pydicom.multival

    # Some anonymized datasets contain non-standard UID strings like "ID_xxx".
    # Ignore only this known noisy warning from pydicom.
    warnings.filterwarnings(
        "ignore",
        message=r"Invalid value for VR UI:",
        category=UserWarning,
        module=r"pydicom\.valuerep",
    )
    has_pydicom = True
except ImportError:
    pass

# ══════════════════════════════════════════════════════════════════════════
# CONFIG — edit these paths before running
# ══════════════════════════════════════════════════════════════════════════

SCRIPT_DIR = Path(__file__).resolve().parent

# Model artifacts (required)
FOLD_MODEL_PATHS = [SCRIPT_DIR / f"best_model_fold{i}.pth" for i in range(5)]
CALIB_PARAMS_PATH = SCRIPT_DIR / "calibration_params.json"
ISOTONIC_MODELS_PATH = SCRIPT_DIR / "isotonic_models.pkl"
NORM_STATS_PATH = SCRIPT_DIR / "normalization_stats.json"

# Input — folder containing .dcm files (see INFERENCE_USAGE.md)
DICOM_INPUT_DIR = SCRIPT_DIR / "dicom_inputs"

# Optional labels (only for quick validation against known RSNA IDs)
MANIFEST_PATH = SCRIPT_DIR / "manifest.csv"

# Output
OUTPUT_DIR = SCRIPT_DIR / "outputs"

# Architecture constants (must match training notebooks)
BACKBONE = "tf_efficientnet_b4"
IMG_SIZE = 380
IN_CHANNELS = 9
N_CLASSES = 6
DROPOUT = 0.4
DROP_PATH = 0.2

# Output / triage behavior
PATIENT_AGG_METHOD = "topk_mean"  # one of: max, mean, noisy_or, topk_mean
PATIENT_TOPK = 3
DECISION_THRESHOLD = None  # None -> use threshold_at_spec90 from calibration JSON
FOLD_SELECTION = "ensemble"  # "ensemble" or an integer fold id: 0..4
GENERATE_HEATMAPS = True

WINDOWS = [
    (40, 80),    # brain
    (75, 215),   # subdural
    (40, 380),   # soft tissue / bone
]

SUBTYPES = [
    "any",
    "epidural",
    "intraparenchymal",
    "intraventricular",
    "subarachnoid",
    "subdural",
]

OUTCOME_POSITIVE = "Hemorrhage indicator detected"
OUTCOME_NEGATIVE = "No hemorrhage indicator detected"

BAND_LABELS = {
    "HIGH": "High confidence",
    "MEDIUM": "Moderate confidence",
    "LOW": "Low confidence",
}

TRIAGE_ACTIONS = {
    ("POSITIVE", "HIGH"): "Urgent radiologist review recommended",
    ("POSITIVE", "MEDIUM"): "Prioritised radiologist review recommended",
    ("POSITIVE", "LOW"): "Radiologist review recommended — low confidence",
    ("NEGATIVE", "HIGH"): "Standard workflow — no urgent action",
    ("NEGATIVE", "MEDIUM"): "Standard workflow — manual review if clinically indicated",
    ("NEGATIVE", "LOW"): "Manual review recommended — model uncertainty high",
}

DISCLAIMER = (
    "This report is produced by an AI-assisted screening tool and does NOT "
    "constitute a medical diagnosis. All screening findings must be reviewed "
    "and confirmed by a qualified, licensed medical professional before any "
    "clinical decision is made. The system is intended solely as a "
    "decision-support aid in a screening workflow and is not cleared for "
    "standalone diagnostic use."
)

IST = ZoneInfo("Asia/Kolkata")


# ══════════════════════════════════════════════════════════════════════════
# DICOM + PREPROCESSING
# ══════════════════════════════════════════════════════════════════════════


def _to_scalar(val) -> float:
    if has_pydicom and isinstance(val, (list, pydicom.multival.MultiValue)):
        return float(val[0])
    return float(val)


def apply_window(img_hu: np.ndarray, wc: float, ww: float) -> np.ndarray:
    lo = wc - ww / 2
    hi = wc + ww / 2
    return np.clip((img_hu - lo) / (hi - lo), 0.0, 1.0)


def load_single_dicom_3ch(dcm_path: Path, size: int = IMG_SIZE) -> np.ndarray:
    if not has_pydicom or pydicom is None:
        raise RuntimeError("pydicom is not installed. Run: pip install pydicom")
    dcm = pydicom.dcmread(str(dcm_path))
    img = dcm.pixel_array.astype(np.float32)

    slope = _to_scalar(getattr(dcm, "RescaleSlope", 1))
    inter = _to_scalar(getattr(dcm, "RescaleIntercept", 0))
    img = img * slope + inter

    channels = []
    for wc, ww in WINDOWS:
        ch = apply_window(img, wc, ww)
        ch = cv2.resize(ch, (size, size), interpolation=cv2.INTER_AREA)
        channels.append(ch)

    return np.stack(channels, axis=-1).astype(np.float32)  # (H, W, 3) in [0,1]


def build_adjacency(dicom_dir: Path) -> pd.DataFrame:
    if not has_pydicom or pydicom is None:
        raise RuntimeError("pydicom is not installed. Run: pip install pydicom")
    records: List[dict] = []
    for dcm_path in sorted(dicom_dir.glob("*.dcm")):
        image_id = dcm_path.stem
        try:
            dcm = pydicom.dcmread(str(dcm_path), stop_before_pixels=True)
            patient_id = str(getattr(dcm, "PatientID", "UNKNOWN"))
            series_uid = str(getattr(dcm, "SeriesInstanceUID", "UNKNOWN_SERIES"))

            ipp = getattr(dcm, "ImagePositionPatient", None)
            if ipp is not None and len(ipp) >= 3:
                z_pos = float(ipp[2])
            else:
                z_pos = float(getattr(dcm, "SliceLocation", 0.0))
        except Exception:
            patient_id = "UNKNOWN"
            series_uid = "UNKNOWN_SERIES"
            z_pos = 0.0

        records.append(
            {
                "image_id": image_id,
                "patient_id": patient_id,
                "series_uid": series_uid,
                "z_pos": z_pos,
                "dcm_path": str(dcm_path),
            }
        )

    if not records:
        return pd.DataFrame(
            columns=["image_id", "patient_id", "series_uid", "z_pos", "dcm_path", "prev_image_id", "next_image_id"]
        )

    df = pd.DataFrame(records)
    df = df.sort_values(["patient_id", "series_uid", "z_pos"]).reset_index(drop=True)
    df["prev_image_id"] = df.groupby(["patient_id", "series_uid"])["image_id"].shift(1)
    df["next_image_id"] = df.groupby(["patient_id", "series_uid"])["image_id"].shift(-1)
    return df


def build_9ch_for_row(row: pd.Series, image_path_map: Dict[str, Path], mean_9: np.ndarray, std_9: np.ndarray) -> np.ndarray:
    center_id = row["image_id"]
    prev_id = row["prev_image_id"] if pd.notna(row.get("prev_image_id")) else None
    next_id = row["next_image_id"] if pd.notna(row.get("next_image_id")) else None

    center_arr = load_single_dicom_3ch(image_path_map[center_id], size=IMG_SIZE)

    if prev_id is not None and prev_id in image_path_map:
        prev_arr = load_single_dicom_3ch(image_path_map[prev_id], size=IMG_SIZE)
    else:
        prev_arr = center_arr

    if next_id is not None and next_id in image_path_map:
        next_arr = load_single_dicom_3ch(image_path_map[next_id], size=IMG_SIZE)
    else:
        next_arr = center_arr

    img_9ch = np.concatenate([prev_arr, center_arr, next_arr], axis=-1).astype(np.float32)
    img_9ch = (img_9ch - mean_9.reshape(1, 1, -1)) / (std_9.reshape(1, 1, -1) + 1e-7)
    return img_9ch


# ══════════════════════════════════════════════════════════════════════════
# MODEL + CALIBRATION
# ══════════════════════════════════════════════════════════════════════════


def build_model(
    backbone: str = BACKBONE,
    in_ch: int = IN_CHANNELS,
    n_cls: int = N_CLASSES,
    dropout: float = DROPOUT,
    drop_path: float = DROP_PATH,
) -> nn.Module:
    model = timm.create_model(
        backbone,
        pretrained=False,
        num_classes=0,
        drop_rate=dropout,
        drop_path_rate=drop_path,
    )

    # Widen the stem conv from 3 to in_ch input channels, repeating the
    # weights and rescaling so activations keep the same magnitude.
    old_conv = model.conv_stem
    new_conv = nn.Conv2d(
        in_ch,
        old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=(old_conv.bias is not None),
    )
    k = max(in_ch // 3, 1)
    with torch.no_grad():
        new_conv.weight.copy_(old_conv.weight.repeat(1, k, 1, 1) / k)
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)
    model.conv_stem = new_conv

    n_feat = model.num_features
    model.classifier = nn.Sequential(nn.Dropout(p=dropout), nn.Linear(n_feat, n_cls))
    return model


def _find_gradcam_target_layer(model: nn.Module) -> nn.Module:
    # Prefer the last semantic convolutional stage for EfficientNet-like models.
    if hasattr(model, "conv_head") and isinstance(model.conv_head, nn.Module):
        return model.conv_head
    conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
    if not conv_layers:
        raise RuntimeError("No convolutional layer found for Grad-CAM target")
    return conv_layers[-1]


class GradCAM:
    def __init__(self, model: nn.Module):
        self.model = model
        self.activations = None
        self.gradients = None
        target = _find_gradcam_target_layer(model)
        self._fh = target.register_forward_hook(self._forward_hook)
        self._bh = target.register_full_backward_hook(self._backward_hook)

    def _forward_hook(self, _module, _inputs, output):
        self.activations = output

    def _backward_hook(self, _module, _grad_input, grad_output):
        self.gradients = grad_output[0]

    def remove(self):
        self._fh.remove()
        self._bh.remove()

    def generate(self, input_tensor: torch.Tensor, class_idx: int = 0) -> Tuple[np.ndarray, np.ndarray]:
        self.model.zero_grad(set_to_none=True)
        use_amp = bool(input_tensor.is_cuda)
        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=use_amp):
                output = self.model(input_tensor)
                target = output[:, class_idx].sum()
            target.backward()

        logits = output.detach().cpu().numpy().astype(np.float32)
        if logits.ndim == 1:
            logits = logits[None, :]

        if self.activations is None or self.gradients is None:
            cam = np.zeros((logits.shape[0], IMG_SIZE, IMG_SIZE), dtype=np.float32)
            return (logits[0], cam[0]) if logits.shape[0] == 1 else (logits, cam)

        acts = self.activations.detach()
        grads = self.gradients.detach()
        weights = grads.mean(dim=(2, 3), keepdim=True)
        cam = torch.relu((weights * acts).sum(dim=1)).cpu().numpy().astype(np.float32)
        if cam.ndim == 2:
            cam = cam[None, ...]

        for idx in range(cam.shape[0]):
            if cam[idx].size == 0 or float(cam[idx].max()) <= 0.0:
                # Keep the feature-map shape; it is upscaled later in make_overlay.
                cam[idx] = np.zeros_like(cam[idx])
            else:
                cam[idx] = (cam[idx] - cam[idx].min()) / (cam[idx].max() - cam[idx].min() + 1e-8)

        return (logits[0], cam[0]) if logits.shape[0] == 1 else (logits, cam)


def make_overlay(orig_rgb_u8: np.ndarray, cam: np.ndarray, alpha: float = 0.45) -> np.ndarray:
    cam_r = cv2.resize(cam, (orig_rgb_u8.shape[1], orig_rgb_u8.shape[0]), interpolation=cv2.INTER_LINEAR)
    heat_u8 = np.uint8(np.clip(cam_r, 0.0, 1.0) * 255.0)
    heat_bgr = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)
    heat_rgb = cv2.cvtColor(heat_bgr, cv2.COLOR_BGR2RGB)
    return (alpha * heat_rgb + (1 - alpha) * orig_rgb_u8).astype(np.uint8)


def load_models(device: str, fold_selection=None) -> Tuple[List[nn.Module], List[int]]:
    models = []
    loaded_folds: List[int] = []

    if fold_selection is None:
        fold_selection = FOLD_SELECTION

    if isinstance(fold_selection, str) and fold_selection.lower() == "ensemble":
        fold_indices = list(range(len(FOLD_MODEL_PATHS)))
    elif isinstance(fold_selection, int):
        fold_indices = [fold_selection]
    elif isinstance(fold_selection, str) and fold_selection.isdigit():
        fold_indices = [int(fold_selection)]
    else:
        raise ValueError('FOLD_SELECTION must be "ensemble" or an integer fold id (0..4).')

    for fold_idx in fold_indices:
        if fold_idx < 0 or fold_idx >= len(FOLD_MODEL_PATHS):
            print(f"  ⚠ Invalid fold index: {fold_idx} (skipping)")
            continue
        path = FOLD_MODEL_PATHS[fold_idx]
        if not path.exists():
            print(f"  ⚠ Missing fold checkpoint: {path.name} (skipping)")
            continue
        model = build_model()
        state = torch.load(str(path), map_location=device)
        model.load_state_dict(state, strict=True)
        model = model.to(device)
        model.eval()
        models.append(model)
        loaded_folds.append(fold_idx)
    return models, loaded_folds


def sigmoid_np(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def apply_calibration(raw_logits: np.ndarray, calib_cfg: dict, iso_models) -> np.ndarray:
    best_method = calib_cfg.get("best_method", "temperature")
    temperature = float(calib_cfg.get("temperature", 1.0))

    if best_method == "isotonic" and iso_models is not None:
        raw_probs = sigmoid_np(raw_logits)
        cal_probs = np.zeros_like(raw_probs, dtype=np.float32)
        for i, subtype in enumerate(SUBTYPES):
            model_i = None
            if isinstance(iso_models, dict):
                model_i = iso_models.get(subtype)
                if model_i is None:
                    model_i = iso_models.get(i)
            elif isinstance(iso_models, (list, tuple)) and i < len(iso_models):
                model_i = iso_models[i]

            if model_i is not None:
                cal_probs[i] = float(np.clip(model_i.predict([raw_probs[i]])[0], 0.0, 1.0))
            else:
                cal_probs[i] = float(raw_probs[i])
        return cal_probs

    return sigmoid_np(raw_logits / max(temperature, 1e-6)).astype(np.float32)


def patient_aggregate(values: np.ndarray, method: str, topk: int) -> float:
    if len(values) == 0:
        return 0.0
    if method == "max":
        return float(np.max(values))
    if method == "mean":
        return float(np.mean(values))
    if method == "noisy_or":
        return float(1.0 - np.prod(1.0 - np.clip(values, 0.0, 1.0)))
    if method == "topk_mean":
        k = min(max(int(topk), 1), len(values))
        top_vals = np.sort(values)[-k:]
        return float(np.mean(top_vals))
    raise ValueError(f"Unknown PATIENT_AGG_METHOD: {method}")


# ══════════════════════════════════════════════════════════════════════════
# REPORT HELPERS
# ══════════════════════════════════════════════════════════════════════════


def build_slice_report(
    image_id: str,
    patient_id: str,
    probs: Dict[str, float],
    calib_cfg: dict,
    threshold: float,
    loaded_folds: List[int],
    report_image_path: Optional[str] = None,
    heatmap_path: Optional[str] = None,
    true_label: Optional[int] = None,
) -> dict:
    cal_any = probs["any"]
    high_thr = float(calib_cfg.get("triage_high_thresh", 0.7))
    low_thr = float(calib_cfg.get("triage_low_thresh", 0.3))

    if cal_any >= high_thr:
        band = "HIGH"
    elif cal_any >= low_thr:
        band = "MEDIUM"
    else:
        band = "LOW"

    is_positive = cal_any >= threshold
    outcome_key = "POSITIVE" if is_positive else "NEGATIVE"

    now_ist = datetime.datetime.now(IST)
    report = {
        "report_id": f"RPT_{now_ist.strftime('%Y%m%d_%H%M%S')}_{image_id[-8:]}",
        "generated_at": now_ist.isoformat(),
        "image_id": image_id,
        "patient_id": patient_id,
        "ground_truth_any": int(true_label) if true_label is not None else "N/A",
        "screening_module": {
            "version": "2.0",
            "architecture": BACKBONE,
            "input_type": "2.5D (9ch: prev+center+next)",
            "ensemble": "ensemble" if len(loaded_folds) > 1 else "single-fold",
            "folds_used": loaded_folds,
            "calibration_method": calib_cfg.get("best_method", "temperature"),
        },
        "prediction": {
            "screening_outcome": OUTCOME_POSITIVE if is_positive else OUTCOME_NEGATIVE,
            "decision_threshold_any": round(float(threshold), 6),
            "confidence_band": band,
            "confidence_band_label": BAND_LABELS[band],
            **{f"calibrated_prob_{k}": round(float(v), 6) for k, v in probs.items()},
        },
        "triage": {
            "action": TRIAGE_ACTIONS[(outcome_key, band)],
            "urgency": "URGENT" if (is_positive and band == "HIGH") else "STANDARD",
        },
        "disclaimer": DISCLAIMER,
    }

    if report_image_path or heatmap_path:
        report["explainability"] = {
            "method": "Gradient-weighted Class Activation Mapping (Grad-CAM)",
            "image_path": report_image_path,
            "heatmap_path": heatmap_path,
            "note": (
                "Highlighted regions indicate areas with greatest influence on the "
                "screening decision. These are not confirmed anatomical findings."
            ),
        }

    return report


# ══════════════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════════════


def main():
    print("=" * 72)
    print(" ICH SCREENING — Improved 2.5D Inference")
    print("=" * 72)

    if not has_pydicom:
        print("ERROR: pydicom is not installed. Run: pip install pydicom")
        return

    if not DICOM_INPUT_DIR.exists():
        print(f"ERROR: DICOM input folder not found: {DICOM_INPUT_DIR}")
        print("       Create this folder and place .dcm files inside it.")
        return

    for path in [CALIB_PARAMS_PATH, NORM_STATS_PATH]:
        if not path.exists():
            print(f"ERROR: Required file missing: {path}")
            return

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\n  Device            : {device}")

    with open(NORM_STATS_PATH, "r", encoding="utf-8") as f:
        norm = json.load(f)
    mean_9 = np.asarray(norm["mean_9ch"], dtype=np.float32)
    std_9 = np.asarray(norm["std_9ch"], dtype=np.float32)

    with open(CALIB_PARAMS_PATH, "r", encoding="utf-8") as f:
        calib_cfg = json.load(f)

    iso_models = None
    if ISOTONIC_MODELS_PATH.exists():
        with open(ISOTONIC_MODELS_PATH, "rb") as f:
            iso_models = pickle.load(f)

    threshold = (
        float(DECISION_THRESHOLD)
        if DECISION_THRESHOLD is not None
        else float(calib_cfg.get("threshold_at_spec90", 0.5))
    )

    print(f"  Backbone          : {BACKBONE}")
    print(f"  Input             : {IN_CHANNELS}ch @ {IMG_SIZE}x{IMG_SIZE}")
    print(f"  Calibration       : {calib_cfg.get('best_method', 'temperature')}")
    print(f"  Decision threshold: {threshold:.6f}")

    models, loaded_folds = load_models(device, fold_selection=FOLD_SELECTION)
    if not models:
        print("ERROR: No fold checkpoints could be loaded.")
        return
    print(f"  Fold models loaded: {len(models)} (folds: {loaded_folds})")
    gradcam_objects = [GradCAM(m) for m in models] if GENERATE_HEATMAPS else []

    adjacency_df = build_adjacency(DICOM_INPUT_DIR)
    if adjacency_df.empty:
        print(f"ERROR: No .dcm files found in {DICOM_INPUT_DIR}")
        return

    image_path_map = {
        Path(p).stem: Path(p)
        for p in adjacency_df["dcm_path"].tolist()
    }

    label_map: Dict[str, int] = {}
    if MANIFEST_PATH.exists():
        try:
            manifest = pd.read_csv(MANIFEST_PATH)
            if "image_id" in manifest.columns and "any" in manifest.columns:
                label_map = dict(zip(manifest["image_id"], manifest["any"]))
                print(f"  Manifest labels   : loaded {len(label_map)} rows")
        except Exception as exc:
            print(f"  ⚠ Manifest load skipped: {exc}")

    reports_dir = OUTPUT_DIR / "reports"
    reports_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'─' * 72}")
    print(f"  Processing {len(adjacency_df)} DICOM slices")
    print(f"{'─' * 72}\n")

    slice_rows = []
    report_summary_rows = []
    patient_probs: Dict[str, List[float]] = {}

    for i, row in adjacency_df.iterrows():
        image_id = row["image_id"]
        patient_id = row["patient_id"]

        try:
            img_9ch = build_9ch_for_row(row, image_path_map, mean_9=mean_9, std_9=std_9)
        except Exception as exc:
            print(f"  [{i+1}/{len(adjacency_df)}] SKIP {image_id}: {exc}")
            continue

        tensor = torch.from_numpy(img_9ch).permute(2, 0, 1).unsqueeze(0).to(device)

        fold_logits = []
        fold_cams = []
        if GENERATE_HEATMAPS:
            for model, cam_obj in zip(models, gradcam_objects):
                logits, cam = cam_obj.generate(tensor, class_idx=0)
                fold_logits.append(logits)
                fold_cams.append(cam)
        else:
            with torch.no_grad():
                for model in models:
                    logits = model(tensor).squeeze(0).detach().cpu().numpy().astype(np.float32)
                    fold_logits.append(logits)

        mean_logits = np.mean(np.stack(fold_logits, axis=0), axis=0)
        raw_probs = sigmoid_np(mean_logits)
        cal_probs = apply_calibration(mean_logits, calib_cfg, iso_models)

        probs_dict = {name: float(cal_probs[j]) for j, name in enumerate(SUBTYPES)}

        # Save a per-slice visualization image (windowed center slice) for report artifacts.
        preview_path = reports_dir / f"{image_id}_preview.png"
        heatmap_path = reports_dir / f"{image_id}_gradcam.png"
        try:
            center_rgb = load_single_dicom_3ch(Path(row["dcm_path"]), size=IMG_SIZE)
            center_rgb_u8 = (np.clip(center_rgb, 0.0, 1.0) * 255.0).astype(np.uint8)
            cv2.imwrite(str(preview_path), cv2.cvtColor(center_rgb_u8, cv2.COLOR_RGB2BGR))
            if GENERATE_HEATMAPS:
                if fold_cams:
                    mean_cam = np.mean(np.stack(fold_cams, axis=0), axis=0)
                else:
                    mean_cam = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.float32)
                overlay_rgb = make_overlay(center_rgb_u8, mean_cam, alpha=0.45)
                cv2.imwrite(str(heatmap_path), cv2.cvtColor(overlay_rgb, cv2.COLOR_RGB2BGR))
            report_image_path = str(preview_path)
            report_heatmap_path = str(heatmap_path) if GENERATE_HEATMAPS else ""
        except Exception:
            report_image_path = ""
            report_heatmap_path = ""

        true_any = label_map.get(image_id)
        rep = build_slice_report(
            image_id=image_id,
            patient_id=patient_id,
            probs=probs_dict,
            calib_cfg=calib_cfg,
            threshold=threshold,
            loaded_folds=loaded_folds,
            report_image_path=report_image_path,
            heatmap_path=report_heatmap_path,
            true_label=int(true_any) if true_any is not None else None,
        )

        report_path = reports_dir / f"{image_id}_report.json"
        with open(report_path, "w", encoding="utf-8") as f:
            json.dump(rep, f, separators=(",", ":"), ensure_ascii=True)

        slice_rows.append(
            {
                "image_id": image_id,
                "patient_id": patient_id,
                "true_any": int(true_any) if true_any is not None else "",
                "pred_any": int(probs_dict["any"] >= threshold),
                "cal_any": round(probs_dict["any"], 6),
                "raw_any": round(float(raw_probs[0]), 6),
                **{f"cal_{name}": round(float(probs_dict[name]), 6) for name in SUBTYPES[1:]},
                "confidence_band": rep["prediction"]["confidence_band"],
                "triage_action": rep["triage"]["action"],
                "urgency": rep["triage"]["urgency"],
            }
        )

        report_summary_rows.append(
            {
                "image_id": image_id,
                "true_label": int(true_any) if true_any is not None else "",
                "screening_outcome": rep["prediction"]["screening_outcome"],
                "raw_prob": round(float(raw_probs[0]), 6),
                "cal_prob": round(float(probs_dict["any"]), 6),
                "confidence_band": rep["prediction"]["confidence_band"],
                "triage_action": rep["triage"]["action"],
                "urgency": rep["triage"]["urgency"],
                "image_path": report_image_path,
                "heatmap_path": report_heatmap_path,
            }
        )

        patient_probs.setdefault(patient_id, []).append(probs_dict["any"])

        status = "[+] POS" if probs_dict["any"] >= threshold else "[-] NEG"
        print(
            f"  [{i+1}/{len(adjacency_df)}] {image_id} → {status} "
            f"cal_any={probs_dict['any']:.4f}"
        )

    if not slice_rows:
        print("\nERROR: No slices were processed successfully.")
        return

    slice_df = pd.DataFrame(slice_rows)
    slice_csv_path = OUTPUT_DIR / "slice_predictions.csv"
    slice_df.to_csv(slice_csv_path, index=False)

    report_summary_df = pd.DataFrame(report_summary_rows)
    report_summary_csv_path = OUTPUT_DIR / "report_summary.csv"
    report_summary_df.to_csv(report_summary_csv_path, index=False)

    patient_rows = []
    for pid, vals in patient_probs.items():
        arr = np.asarray(vals, dtype=np.float32)
        agg_prob = patient_aggregate(arr, PATIENT_AGG_METHOD, PATIENT_TOPK)
        patient_rows.append(
            {
                "patient_id": pid,
                "n_slices": int(len(arr)),
                "agg_method": PATIENT_AGG_METHOD,
                "agg_any_probability": round(float(agg_prob), 6),
                "pred_any": int(agg_prob >= threshold),
            }
        )

    patient_df = pd.DataFrame(patient_rows)
    patient_csv_path = OUTPUT_DIR / "patient_predictions.csv"
    patient_df.to_csv(patient_csv_path, index=False)

    for cam_obj in gradcam_objects:
        cam_obj.remove()

    n_pos = int((slice_df["pred_any"] == 1).sum())
    n_total = len(slice_df)
    n_urgent = int((slice_df["urgency"] == "URGENT").sum())

    print(f"\n{'═' * 72}")
    print(" INFERENCE COMPLETE")
    print(f"{'═' * 72}")
    print(f"  Slices processed   : {n_total}")
    print(f"  Positive slices    : {n_pos}")
    print(f"  Urgent escalations : {n_urgent}")
    print(f"  Patients processed : {len(patient_df)}")
    print("\n  Outputs:")
    print(f"    JSON reports   : {reports_dir}")
    print(f"    Report images  : {reports_dir}")
    print(f"    Report summary : {report_summary_csv_path}")
    print(f"    Slice CSV      : {slice_csv_path}")
    print(f"    Patient CSV    : {patient_csv_path}")
    print(f"{'═' * 72}")


if __name__ == "__main__":
    main()