Harshit Ghosh committed
Commit · ea664f8
Parent(s): 410e48e
making changes for huggingface

Browse files
- .gitignore +1 -1
- download_imp/INFERENCE_USAGE.md +107 -0
- download_imp/calibration_params.json +17 -0
- download_imp/manifest.csv +0 -0
- download_imp/normalization_stats.json +38 -0
- download_imp/run_inference.py +746 -0

.gitignore CHANGED
@@ -45,7 +45,7 @@ download_imp/*.pt
 download_imp/*.pkl
 download_imp/*.onnx
 
-download_imp/*
+# download_imp/*
 # Local downloaded artifacts
 download/
 # Local downloaded data

download_imp/INFERENCE_USAGE.md ADDED
@@ -0,0 +1,107 @@
# Improved Model Inference — Usage Guide (`download_imp`)

## What this runner does

`run_inference.py` runs the **latest improved model** trained from the notebooks in `improvement/`:

- EfficientNet-B4 backbone (`tf_efficientnet_b4`)
- 2.5D input (prev + center + next slices → 9 channels)
- 6 outputs (`any` + 5 hemorrhage subtypes)
- 5-fold ensemble (`best_model_fold0..4.pth`)
- Saved calibration (`isotonic`/`temperature`) from `calibration_params.json`

Outputs:

- Per-slice JSON report (`outputs/reports/*.json`)
- Slice-level CSV (`outputs/slice_predictions.csv`)
- Patient-level CSV (`outputs/patient_predictions.csv`)

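After a run, the two CSVs can be sanity-checked with pandas. This is a minimal sketch only; it assumes the default `outputs/` location created by the runner and uses the column names written by `run_inference.py`:

```python
import pandas as pd

# Load the slice-level and patient-level summaries written by run_inference.py
slice_df = pd.read_csv("outputs/slice_predictions.csv")
patient_df = pd.read_csv("outputs/patient_predictions.csv")

# Slices flagged positive for "any" hemorrhage, highest calibrated probability first
flagged = slice_df[slice_df["pred_any"] == 1].sort_values("cal_any", ascending=False)
print(flagged[["image_id", "patient_id", "cal_any", "confidence_band", "triage_action"]].head())

# Patient-level aggregation results
print(patient_df[["patient_id", "n_slices", "agg_any_probability", "pred_any"]].head())
```
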
---

## Required files (already in `download_imp/`)

- `best_model_fold0.pth`
- `best_model_fold1.pth`
- `best_model_fold2.pth`
- `best_model_fold3.pth`
- `best_model_fold4.pth`
- `calibration_params.json`
- `isotonic_models.pkl`
- `normalization_stats.json`
- `manifest.csv` (optional at inference time; used only for `true_any` if IDs match)

---

## Python package requirements

```bash
pip install -r requirements.txt
```

Notes:

- `timm` is required for `tf_efficientnet_b4` model construction.
- `scikit-learn` is needed to deserialize and use `isotonic_models.pkl`.

---

## Folder setup

Create the `dicom_inputs/` folder and place the DICOM files there:

```text
download_imp/
├── run_inference.py
├── run_interface.py
├── best_model_fold0.pth
├── best_model_fold1.pth
├── best_model_fold2.pth
├── best_model_fold3.pth
├── best_model_fold4.pth
├── calibration_params.json
├── isotonic_models.pkl
├── normalization_stats.json
├── manifest.csv
└── dicom_inputs/
    ├── ID_xxx1.dcm
    ├── ID_xxx2.dcm
    └── ...
```

---

## Run commands

From the workspace root:

```bash
cd download_imp
python run_inference.py
```

or (same thing):

```bash
python run_interface.py
```

---

## Important behavior

- No CLI arguments; all settings are at the top of `run_inference.py` (`CONFIG` section).
- `FOLD_SELECTION` controls checkpoint selection:
  - `"ensemble"` = use all available folds and average logits
  - `0..4` = use one specific fold only
- If `best_method` is `isotonic`, the runner uses `isotonic_models.pkl`.
- A missing prev/next slice in a series is handled exactly like the training cache logic: the neighbor falls back to the center slice.
- The decision threshold defaults to `threshold_at_spec90` from `calibration_params.json` unless overridden in the config.

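For example, to run a single fold with a custom operating point and without Grad-CAM overlays, the `CONFIG` block at the top of `run_inference.py` could be edited like this (a sketch only; the values are illustrative, not recommendations):

```python
# Illustrative CONFIG edits in run_inference.py (example values, not the shipped defaults)
FOLD_SELECTION = 2           # use best_model_fold2.pth only, instead of "ensemble"
DECISION_THRESHOLD = 0.5     # override threshold_at_spec90 from calibration_params.json
PATIENT_AGG_METHOD = "max"   # patient score = max over that patient's slice probabilities
GENERATE_HEATMAPS = False    # skip Grad-CAM overlay generation for faster runs
```
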
---

## Recommended production checklist

1. Keep all fold checkpoints and calibration files in the same `download_imp/` directory.
2. Verify DICOMs are non-contrast head CT slices before inference.
3. Run once on a small sample and review `slice_predictions.csv` and the JSON reports.
4. Have a radiologist review all flagged and uncertain cases.

download_imp/calibration_params.json ADDED
@@ -0,0 +1,17 @@
{
  "best_method": "isotonic",
  "temperature": 0.7681182881076687,
  "ece_raw": 0.029524699832706974,
  "ece_temp": 0.008160804909196835,
  "ece_isotonic": 6.68228018673829e-18,
  "brier_raw": 0.05208174680122169,
  "brier_temp": 0.0514404718775377,
  "brier_isotonic": 0.050937655679539166,
  "triage_high_thresh": 0.7,
  "triage_low_thresh": 0.3,
  "oof_auc_any": 0.9501744154042578,
  "sensitivity_at_spec90": 0.8556471787269526,
  "specificity_at_sens95": 0.7441809977204707,
  "threshold_at_spec90": 0.16216216216216217,
  "threshold_at_sens95": 0.04074505238649592
}
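
For reference, this is roughly how `run_inference.py` consumes these values when turning a calibrated `any` probability into a screening decision and a confidence band (a minimal sketch; the probability value is hypothetical, the thresholds come from the JSON above):

```python
import json

with open("calibration_params.json", "r", encoding="utf-8") as f:
    calib = json.load(f)

cal_any = 0.42  # hypothetical calibrated probability for the "any" output

# Decision: the default operating point is threshold_at_spec90 (unless overridden in CONFIG)
is_positive = cal_any >= calib["threshold_at_spec90"]

# Confidence band: triage thresholds from the same file
if cal_any >= calib["triage_high_thresh"]:
    band = "HIGH"
elif cal_any >= calib["triage_low_thresh"]:
    band = "MEDIUM"
else:
    band = "LOW"

print(is_positive, band)  # with the value above: True MEDIUM
```
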
download_imp/manifest.csv ADDED
The diff for this file is too large to render. See raw diff.

download_imp/normalization_stats.json ADDED
@@ -0,0 +1,38 @@
{
  "mean_3ch": [
    0.162136,
    0.141483,
    0.183675
  ],
  "std_3ch": [
    0.312067,
    0.283885,
    0.305968
  ],
  "mean_9ch": [
    0.162136,
    0.141483,
    0.183675,
    0.162136,
    0.141483,
    0.183675,
    0.162136,
    0.141483,
    0.183675
  ],
  "std_9ch": [
    0.312067,
    0.283885,
    0.305968,
    0.312067,
    0.283885,
    0.305968,
    0.312067,
    0.283885,
    0.305968
  ],
  "n_sample_images": 5000,
  "n_pixels": 722000000,
  "img_size": 380,
  "note": "Stats computed on windowed [0,1] images. Pre-normalization applied in cache."
}
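
The `mean_9ch`/`std_9ch` entries are the 3-channel statistics tiled across the prev/center/next slices; `run_inference.py` applies them per channel to the stacked 2.5D input. A minimal sketch (the input array is a hypothetical stand-in for a real windowed slice stack):

```python
import json
import numpy as np

with open("normalization_stats.json", "r", encoding="utf-8") as f:
    stats = json.load(f)

mean_9 = np.asarray(stats["mean_9ch"], dtype=np.float32)
std_9 = np.asarray(stats["std_9ch"], dtype=np.float32)

# Hypothetical 2.5D stack: prev + center + next slices, each windowed into 3 channels in [0, 1]
img_9ch = np.random.rand(stats["img_size"], stats["img_size"], 9).astype(np.float32)

# Per-channel standardization, as done in build_9ch_for_row() before the forward pass
img_9ch = (img_9ch - mean_9.reshape(1, 1, -1)) / (std_9.reshape(1, 1, -1) + 1e-7)
print(img_9ch.shape)  # (380, 380, 9)
```
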
download_imp/run_inference.py ADDED
@@ -0,0 +1,746 @@
"""
Standalone Inference Script — Improved ICH Screening (2.5D, 5-fold Ensemble)
=============================================================================
Reads raw DICOM CT brain slices, reproduces improved preprocessing from the
improvement notebooks, runs 5-fold EfficientNet-B4 ensemble inference, applies
saved calibration, and generates:
  • Per-image JSON reports (fixed schema)
  • Slice-level CSV summary
  • Patient-level CSV summary

No command-line arguments — all paths are configured in the CONFIG section.

Requirements:
    pip install torch timm pydicom opencv-python-headless numpy pandas scikit-learn

Usage:
    python run_inference.py
"""

import datetime
import json
import pickle
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from zoneinfo import ZoneInfo

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import timm

# Try importing pydicom — needed for DICOM input mode
has_pydicom = False
pydicom = None
try:
    import pydicom
    import pydicom.multival

    # Some anonymized datasets contain non-standard UID strings like "ID_xxx".
    # Ignore only this known noisy warning from pydicom.
    warnings.filterwarnings(
        "ignore",
        message=r"Invalid value for VR UI:",
        category=UserWarning,
        module=r"pydicom\.valuerep",
    )
    has_pydicom = True
except ImportError:
    pass

# ══════════════════════════════════════════════════════════════════════════
# CONFIG — edit these paths before running
# ══════════════════════════════════════════════════════════════════════════

SCRIPT_DIR = Path(__file__).resolve().parent

# Model artifacts (required)
FOLD_MODEL_PATHS = [SCRIPT_DIR / f"best_model_fold{i}.pth" for i in range(5)]
CALIB_PARAMS_PATH = SCRIPT_DIR / "calibration_params.json"
ISOTONIC_MODELS_PATH = SCRIPT_DIR / "isotonic_models.pkl"
NORM_STATS_PATH = SCRIPT_DIR / "normalization_stats.json"

# Input — folder containing .dcm files (point this at your DICOM folder, e.g. SCRIPT_DIR / "dicom_inputs")
DICOM_INPUT_DIR = Path(r"D:\major8thsem\stage_2_test")

# Optional labels (only for quick validation against known RSNA IDs)
MANIFEST_PATH = SCRIPT_DIR / "manifest.csv"

# Output
OUTPUT_DIR = SCRIPT_DIR / "outputs"

# Architecture constants (must match training notebooks)
BACKBONE = "tf_efficientnet_b4"
IMG_SIZE = 380
IN_CHANNELS = 9
N_CLASSES = 6
DROPOUT = 0.4
DROP_PATH = 0.2

# Output / triage behavior
PATIENT_AGG_METHOD = "topk_mean"  # one of: max, mean, noisy_or, topk_mean
PATIENT_TOPK = 3
DECISION_THRESHOLD = None  # None -> use threshold_at_spec90 from calibration JSON
FOLD_SELECTION = "ensemble"  # "ensemble" or an integer fold id: 0..4
GENERATE_HEATMAPS = True

WINDOWS = [
    (40, 80),    # brain
    (75, 215),   # subdural
    (40, 380),   # soft tissue / bone
]

SUBTYPES = [
    "any",
    "epidural",
    "intraparenchymal",
    "intraventricular",
    "subarachnoid",
    "subdural",
]

OUTCOME_POSITIVE = "Hemorrhage indicator detected"
OUTCOME_NEGATIVE = "No hemorrhage indicator detected"

BAND_LABELS = {
    "HIGH": "High confidence",
    "MEDIUM": "Moderate confidence",
    "LOW": "Low confidence",
}

TRIAGE_ACTIONS = {
    ("POSITIVE", "HIGH"): "Urgent radiologist review recommended",
    ("POSITIVE", "MEDIUM"): "Prioritised radiologist review recommended",
    ("POSITIVE", "LOW"): "Radiologist review recommended — low confidence",
    ("NEGATIVE", "HIGH"): "Standard workflow — no urgent action",
    ("NEGATIVE", "MEDIUM"): "Standard workflow — manual review if clinically indicated",
    ("NEGATIVE", "LOW"): "Manual review recommended — model uncertainty high",
}

DISCLAIMER = (
    "This report is produced by an AI-assisted screening tool and does NOT "
    "constitute a medical diagnosis. All screening findings must be reviewed "
    "and confirmed by a qualified, licensed medical professional before any "
    "clinical decision is made. The system is intended solely as a "
    "decision-support aid in a screening workflow and is not cleared for "
    "standalone diagnostic use."
)

IST = ZoneInfo("Asia/Kolkata")


# ══════════════════════════════════════════════════════════════════════════
# DICOM + PREPROCESSING
# ══════════════════════════════════════════════════════════════════════════


def _to_scalar(val) -> float:
    if has_pydicom and isinstance(val, (list, pydicom.multival.MultiValue)):
        return float(val[0])
    return float(val)


def apply_window(img_hu: np.ndarray, wc: float, ww: float) -> np.ndarray:
    lo = wc - ww / 2
    hi = wc + ww / 2
    return np.clip((img_hu - lo) / (hi - lo), 0.0, 1.0)


def load_single_dicom_3ch(dcm_path: Path, size: int = IMG_SIZE) -> np.ndarray:
    if not has_pydicom or pydicom is None:
        raise RuntimeError("pydicom is not installed. Run: pip install pydicom")
    dcm = pydicom.dcmread(str(dcm_path))
    img = dcm.pixel_array.astype(np.float32)

    slope = _to_scalar(getattr(dcm, "RescaleSlope", 1))
    inter = _to_scalar(getattr(dcm, "RescaleIntercept", 0))
    img = img * slope + inter

    channels = []
    for wc, ww in WINDOWS:
        ch = apply_window(img, wc, ww)
        ch = cv2.resize(ch, (size, size), interpolation=cv2.INTER_AREA)
        channels.append(ch)

    return np.stack(channels, axis=-1).astype(np.float32)  # (H, W, 3) in [0,1]


def build_adjacency(dicom_dir: Path) -> pd.DataFrame:
    if not has_pydicom or pydicom is None:
        raise RuntimeError("pydicom is not installed. Run: pip install pydicom")
    records: List[dict] = []
    for dcm_path in sorted(dicom_dir.glob("*.dcm")):
        image_id = dcm_path.stem
        try:
            dcm = pydicom.dcmread(str(dcm_path), stop_before_pixels=True)
            patient_id = str(getattr(dcm, "PatientID", "UNKNOWN"))
            series_uid = str(getattr(dcm, "SeriesInstanceUID", "UNKNOWN_SERIES"))

            ipp = getattr(dcm, "ImagePositionPatient", None)
            if ipp is not None and len(ipp) >= 3:
                z_pos = float(ipp[2])
            else:
                z_pos = float(getattr(dcm, "SliceLocation", 0.0))
        except Exception:
            patient_id = "UNKNOWN"
            series_uid = "UNKNOWN_SERIES"
            z_pos = 0.0

        records.append(
            {
                "image_id": image_id,
                "patient_id": patient_id,
                "series_uid": series_uid,
                "z_pos": z_pos,
                "dcm_path": str(dcm_path),
            }
        )

    if not records:
        return pd.DataFrame(columns=["image_id", "patient_id", "series_uid", "z_pos", "dcm_path", "prev_image_id", "next_image_id"])

    df = pd.DataFrame(records)
    df = df.sort_values(["patient_id", "series_uid", "z_pos"]).reset_index(drop=True)
    df["prev_image_id"] = df.groupby(["patient_id", "series_uid"])["image_id"].shift(1)
    df["next_image_id"] = df.groupby(["patient_id", "series_uid"])["image_id"].shift(-1)
    return df


def build_9ch_for_row(row: pd.Series, image_path_map: Dict[str, Path], mean_9: np.ndarray, std_9: np.ndarray) -> np.ndarray:
    center_id = row["image_id"]
    prev_id = row["prev_image_id"] if pd.notna(row.get("prev_image_id")) else None
    next_id = row["next_image_id"] if pd.notna(row.get("next_image_id")) else None

    center_arr = load_single_dicom_3ch(image_path_map[center_id], size=IMG_SIZE)

    if prev_id is not None and prev_id in image_path_map:
        prev_arr = load_single_dicom_3ch(image_path_map[prev_id], size=IMG_SIZE)
    else:
        prev_arr = center_arr

    if next_id is not None and next_id in image_path_map:
        next_arr = load_single_dicom_3ch(image_path_map[next_id], size=IMG_SIZE)
    else:
        next_arr = center_arr

    img_9ch = np.concatenate([prev_arr, center_arr, next_arr], axis=-1).astype(np.float32)
    img_9ch = (img_9ch - mean_9.reshape(1, 1, -1)) / (std_9.reshape(1, 1, -1) + 1e-7)
    return img_9ch


# ══════════════════════════════════════════════════════════════════════════
# MODEL + CALIBRATION
# ══════════════════════════════════════════════════════════════════════════


def build_model(
    backbone: str = BACKBONE,
    in_ch: int = IN_CHANNELS,
    n_cls: int = N_CLASSES,
    dropout: float = DROPOUT,
    drop_path: float = DROP_PATH,
) -> nn.Module:
    model = timm.create_model(
        backbone,
        pretrained=False,
        num_classes=0,
        drop_rate=dropout,
        drop_path_rate=drop_path,
    )

    old_conv = model.conv_stem
    new_conv = nn.Conv2d(
        in_ch,
        old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=(old_conv.bias is not None),
    )
    k = max(in_ch // 3, 1)
    with torch.no_grad():
        new_conv.weight.copy_(old_conv.weight.repeat(1, k, 1, 1) / k)
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)
    model.conv_stem = new_conv

    n_feat = model.num_features
    model.classifier = nn.Sequential(nn.Dropout(p=dropout), nn.Linear(n_feat, n_cls))
    return model


def _find_gradcam_target_layer(model: nn.Module) -> nn.Module:
    # Prefer the last semantic convolutional stage for EfficientNet-like models.
    if hasattr(model, "conv_head") and isinstance(model.conv_head, nn.Module):
        return model.conv_head
    conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
    if not conv_layers:
        raise RuntimeError("No convolutional layer found for Grad-CAM target")
    return conv_layers[-1]


class GradCAM:
    def __init__(self, model: nn.Module):
        self.model = model
        self.activations = None
        self.gradients = None
        target = _find_gradcam_target_layer(model)
        self._fh = target.register_forward_hook(self._forward_hook)
        self._bh = target.register_full_backward_hook(self._backward_hook)

    def _forward_hook(self, _module, _inputs, output):
        self.activations = output

    def _backward_hook(self, _module, _grad_input, grad_output):
        self.gradients = grad_output[0]

    def remove(self):
        self._fh.remove()
        self._bh.remove()

    def generate(self, input_tensor: torch.Tensor, class_idx: int = 0) -> Tuple[np.ndarray, np.ndarray]:
        self.model.zero_grad(set_to_none=True)
        use_amp = bool(input_tensor.is_cuda)
        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=use_amp):
                output = self.model(input_tensor)
                target = output[:, class_idx].sum()
            target.backward()

        logits = output.detach().cpu().numpy().astype(np.float32)
        if logits.ndim == 1:
            logits = logits[None, :]

        if self.activations is None or self.gradients is None:
            cam = np.zeros((logits.shape[0], IMG_SIZE, IMG_SIZE), dtype=np.float32)
            return (logits[0], cam[0]) if logits.shape[0] == 1 else (logits, cam)

        acts = self.activations.detach()
        grads = self.gradients.detach()
        weights = grads.mean(dim=(2, 3), keepdim=True)
        cam = torch.relu((weights * acts).sum(dim=1)).cpu().numpy().astype(np.float32)
        if cam.ndim == 2:
            cam = cam[None, ...]

        for idx in range(cam.shape[0]):
            if cam[idx].size == 0 or float(cam[idx].max()) <= 0.0:
                # Keep the CAM's native spatial size; it is upsampled later in make_overlay.
                cam[idx] = np.zeros_like(cam[idx])
            else:
                cam[idx] = (cam[idx] - cam[idx].min()) / (cam[idx].max() - cam[idx].min() + 1e-8)

        return (logits[0], cam[0]) if logits.shape[0] == 1 else (logits, cam)


def make_overlay(orig_rgb_u8: np.ndarray, cam: np.ndarray, alpha: float = 0.45) -> np.ndarray:
    cam_r = cv2.resize(cam, (orig_rgb_u8.shape[1], orig_rgb_u8.shape[0]), interpolation=cv2.INTER_LINEAR)
    heat_u8 = np.uint8(np.clip(cam_r, 0.0, 1.0) * 255.0)
    heat_bgr = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)
    heat_rgb = cv2.cvtColor(heat_bgr, cv2.COLOR_BGR2RGB)
    return (alpha * heat_rgb + (1 - alpha) * orig_rgb_u8).astype(np.uint8)


def load_models(device: str, fold_selection=None) -> Tuple[List[nn.Module], List[int]]:
    models = []
    loaded_folds: List[int] = []

    if fold_selection is None:
        fold_selection = FOLD_SELECTION

    if isinstance(fold_selection, str) and fold_selection.lower() == "ensemble":
        fold_indices = list(range(len(FOLD_MODEL_PATHS)))
    elif isinstance(fold_selection, int):
        fold_indices = [fold_selection]
    elif isinstance(fold_selection, str) and fold_selection.isdigit():
        fold_indices = [int(fold_selection)]
    else:
        raise ValueError('FOLD_SELECTION must be "ensemble" or an integer fold id (0..4).')

    for fold_idx in fold_indices:
        if fold_idx < 0 or fold_idx >= len(FOLD_MODEL_PATHS):
            print(f"  ⚠ Invalid fold index: {fold_idx} (skipping)")
            continue
        path = FOLD_MODEL_PATHS[fold_idx]
        if not path.exists():
            print(f"  ⚠ Missing fold checkpoint: {path.name} (skipping)")
            continue
        model = build_model()
        state = torch.load(str(path), map_location=device)
        model.load_state_dict(state, strict=True)
        model = model.to(device)
        model.eval()
        models.append(model)
        loaded_folds.append(fold_idx)
    return models, loaded_folds


def sigmoid_np(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def apply_calibration(raw_logits: np.ndarray, calib_cfg: dict, iso_models) -> np.ndarray:
    best_method = calib_cfg.get("best_method", "temperature")
    temperature = float(calib_cfg.get("temperature", 1.0))

    if best_method == "isotonic" and iso_models is not None:
        raw_probs = sigmoid_np(raw_logits)
        cal_probs = np.zeros_like(raw_probs, dtype=np.float32)
        for i, subtype in enumerate(SUBTYPES):
            model_i = None
            if isinstance(iso_models, dict):
                model_i = iso_models.get(subtype)
                if model_i is None:
                    model_i = iso_models.get(i)
            elif isinstance(iso_models, (list, tuple)) and i < len(iso_models):
                model_i = iso_models[i]

            if model_i is not None:
                cal_probs[i] = float(np.clip(model_i.predict([raw_probs[i]])[0], 0.0, 1.0))
            else:
                cal_probs[i] = float(raw_probs[i])
        return cal_probs

    return sigmoid_np(raw_logits / max(temperature, 1e-6)).astype(np.float32)


def patient_aggregate(values: np.ndarray, method: str, topk: int) -> float:
    if len(values) == 0:
        return 0.0
    if method == "max":
        return float(np.max(values))
    if method == "mean":
        return float(np.mean(values))
    if method == "noisy_or":
        return float(1.0 - np.prod(1.0 - np.clip(values, 0.0, 1.0)))
    if method == "topk_mean":
        k = min(max(int(topk), 1), len(values))
        top_vals = np.sort(values)[-k:]
        return float(np.mean(top_vals))
    raise ValueError(f"Unknown PATIENT_AGG_METHOD: {method}")


# ══════════════════════════════════════════════════════════════════════════
# REPORT HELPERS
# ══════════════════════════════════════════════════════════════════════════


def build_slice_report(
    image_id: str,
    patient_id: str,
    probs: Dict[str, float],
    calib_cfg: dict,
    threshold: float,
    loaded_folds: List[int],
    report_image_path: Optional[str] = None,
    heatmap_path: Optional[str] = None,
    true_label: Optional[int] = None,
) -> dict:
    cal_any = probs["any"]
    high_thr = float(calib_cfg.get("triage_high_thresh", 0.7))
    low_thr = float(calib_cfg.get("triage_low_thresh", 0.3))

    if cal_any >= high_thr:
        band = "HIGH"
    elif cal_any >= low_thr:
        band = "MEDIUM"
    else:
        band = "LOW"

    is_positive = cal_any >= threshold
    outcome_key = "POSITIVE" if is_positive else "NEGATIVE"

    now_ist = datetime.datetime.now(IST)
    report = {
        "report_id": f"RPT_{now_ist.strftime('%Y%m%d_%H%M%S')}_{image_id[-8:]}",
        "generated_at": now_ist.isoformat(),
        "image_id": image_id,
        "patient_id": patient_id,
        "ground_truth_any": int(true_label) if true_label is not None else "N/A",
        "screening_module": {
            "version": "2.0",
            "architecture": BACKBONE,
            "input_type": "2.5D (9ch: prev+center+next)",
            "ensemble": "ensemble" if len(loaded_folds) > 1 else "single-fold",
            "folds_used": loaded_folds,
            "calibration_method": calib_cfg.get("best_method", "temperature"),
        },
        "prediction": {
            "screening_outcome": OUTCOME_POSITIVE if is_positive else OUTCOME_NEGATIVE,
            "decision_threshold_any": round(float(threshold), 6),
            "confidence_band": band,
            "confidence_band_label": BAND_LABELS[band],
            **{f"calibrated_prob_{k}": round(float(v), 6) for k, v in probs.items()},
        },
        "triage": {
            "action": TRIAGE_ACTIONS[(outcome_key, band)],
            "urgency": "URGENT" if (is_positive and band == "HIGH") else "STANDARD",
        },
        "disclaimer": DISCLAIMER,
    }

    if report_image_path or heatmap_path:
        report["explainability"] = {
            "method": "Gradient-weighted Class Activation Mapping (Grad-CAM)",
            "image_path": report_image_path,
            "heatmap_path": heatmap_path,
            "note": (
                "Highlighted regions indicate areas with greatest influence on the "
                "screening decision. These are not confirmed anatomical findings."
            ),
        }

    return report


# ══════════════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════════════


def main():
    print("=" * 72)
    print(" ICH SCREENING — Improved 2.5D Inference")
    print("=" * 72)

    if not has_pydicom:
        print("ERROR: pydicom is not installed. Run: pip install pydicom")
        return

    if not DICOM_INPUT_DIR.exists():
        print(f"ERROR: DICOM input folder not found: {DICOM_INPUT_DIR}")
        print("       Create this folder and place .dcm files inside it.")
        return

    for path in [CALIB_PARAMS_PATH, NORM_STATS_PATH]:
        if not path.exists():
            print(f"ERROR: Required file missing: {path}")
            return

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\n Device : {device}")

    with open(NORM_STATS_PATH, "r", encoding="utf-8") as f:
        norm = json.load(f)
    mean_9 = np.asarray(norm["mean_9ch"], dtype=np.float32)
    std_9 = np.asarray(norm["std_9ch"], dtype=np.float32)

    with open(CALIB_PARAMS_PATH, "r", encoding="utf-8") as f:
        calib_cfg = json.load(f)

    iso_models = None
    if ISOTONIC_MODELS_PATH.exists():
        with open(ISOTONIC_MODELS_PATH, "rb") as f:
            iso_models = pickle.load(f)

    threshold = (
        float(DECISION_THRESHOLD)
        if DECISION_THRESHOLD is not None
        else float(calib_cfg.get("threshold_at_spec90", 0.5))
    )

    print(f" Backbone : {BACKBONE}")
    print(f" Input : {IN_CHANNELS}ch @ {IMG_SIZE}x{IMG_SIZE}")
    print(f" Calibration : {calib_cfg.get('best_method', 'temperature')}")
    print(f" Decision threshold: {threshold:.6f}")

    models, loaded_folds = load_models(device, fold_selection=FOLD_SELECTION)
    if not models:
        print("ERROR: No fold checkpoints could be loaded.")
        return
    print(f" Fold models loaded: {len(models)} (folds: {loaded_folds})")
    gradcam_objects = [GradCAM(m) for m in models] if GENERATE_HEATMAPS else []

    adjacency_df = build_adjacency(DICOM_INPUT_DIR)
    if adjacency_df.empty:
        print(f"ERROR: No .dcm files found in {DICOM_INPUT_DIR}")
        return

    image_path_map = {
        Path(p).stem: Path(p)
        for p in adjacency_df["dcm_path"].tolist()
    }

    label_map: Dict[str, int] = {}
    if MANIFEST_PATH.exists():
        try:
            manifest = pd.read_csv(MANIFEST_PATH)
            if "image_id" in manifest.columns and "any" in manifest.columns:
                label_map = dict(zip(manifest["image_id"], manifest["any"]))
                print(f" Manifest labels : loaded {len(label_map)} rows")
        except Exception as exc:
            print(f" ⚠ Manifest load skipped: {exc}")

    reports_dir = OUTPUT_DIR / "reports"
    reports_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'─' * 72}")
    print(f" Processing {len(adjacency_df)} DICOM slices")
    print(f"{'─' * 72}\n")

    slice_rows = []
    report_summary_rows = []
    patient_probs: Dict[str, List[float]] = {}

    for i, row in adjacency_df.iterrows():
        image_id = row["image_id"]
        patient_id = row["patient_id"]

        try:
            img_9ch = build_9ch_for_row(row, image_path_map, mean_9=mean_9, std_9=std_9)
        except Exception as exc:
            print(f" [{i+1}/{len(adjacency_df)}] SKIP {image_id}: {exc}")
            continue

        tensor = torch.from_numpy(img_9ch).permute(2, 0, 1).unsqueeze(0).to(device)

        fold_logits = []
        fold_cams = []
        if GENERATE_HEATMAPS:
            for model, cam_obj in zip(models, gradcam_objects):
                logits, cam = cam_obj.generate(tensor, class_idx=0)
                fold_logits.append(logits)
                fold_cams.append(cam)
        else:
            with torch.no_grad():
                for model in models:
                    logits = model(tensor).squeeze(0).detach().cpu().numpy().astype(np.float32)
                    fold_logits.append(logits)

        mean_logits = np.mean(np.stack(fold_logits, axis=0), axis=0)
        raw_probs = sigmoid_np(mean_logits)
        cal_probs = apply_calibration(mean_logits, calib_cfg, iso_models)

        probs_dict = {name: float(cal_probs[j]) for j, name in enumerate(SUBTYPES)}

        # Save a per-slice visualization image (windowed center slice) for report artifacts.
        preview_path = reports_dir / f"{image_id}_preview.png"
        heatmap_path = reports_dir / f"{image_id}_gradcam.png"
        try:
            center_rgb = load_single_dicom_3ch(Path(row["dcm_path"]), size=IMG_SIZE)
            center_rgb_u8 = (np.clip(center_rgb, 0.0, 1.0) * 255.0).astype(np.uint8)
            cv2.imwrite(str(preview_path), cv2.cvtColor(center_rgb_u8, cv2.COLOR_RGB2BGR))
            if GENERATE_HEATMAPS:
                if fold_cams:
                    mean_cam = np.mean(np.stack(fold_cams, axis=0), axis=0)
                else:
                    mean_cam = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.float32)
                overlay_rgb = make_overlay(center_rgb_u8, mean_cam, alpha=0.45)
                cv2.imwrite(str(heatmap_path), cv2.cvtColor(overlay_rgb, cv2.COLOR_RGB2BGR))
            report_image_path = str(preview_path)
            report_heatmap_path = str(heatmap_path) if GENERATE_HEATMAPS else ""
        except Exception:
            report_image_path = ""
            report_heatmap_path = ""

        true_any = label_map.get(image_id)
        rep = build_slice_report(
            image_id=image_id,
            patient_id=patient_id,
            probs=probs_dict,
            calib_cfg=calib_cfg,
            threshold=threshold,
            loaded_folds=loaded_folds,
            report_image_path=report_image_path,
            heatmap_path=report_heatmap_path,
            true_label=int(true_any) if true_any is not None else None,
        )

        report_path = reports_dir / f"{image_id}_report.json"
        with open(report_path, "w", encoding="utf-8") as f:
            json.dump(rep, f, separators=(",", ":"), ensure_ascii=True)

        slice_rows.append(
            {
                "image_id": image_id,
                "patient_id": patient_id,
                "true_any": int(true_any) if true_any is not None else "",
                "pred_any": int(probs_dict["any"] >= threshold),
                "cal_any": round(probs_dict["any"], 6),
                "raw_any": round(float(raw_probs[0]), 6),
                **{f"cal_{name}": round(float(probs_dict[name]), 6) for name in SUBTYPES[1:]},
                "confidence_band": rep["prediction"]["confidence_band"],
                "triage_action": rep["triage"]["action"],
                "urgency": rep["triage"]["urgency"],
            }
        )

        report_summary_rows.append(
            {
                "image_id": image_id,
                "true_label": int(true_any) if true_any is not None else "",
                "screening_outcome": rep["prediction"]["screening_outcome"],
                "raw_prob": round(float(raw_probs[0]), 6),
                "cal_prob": round(float(probs_dict["any"]), 6),
                "confidence_band": rep["prediction"]["confidence_band"],
                "triage_action": rep["triage"]["action"],
                "urgency": rep["triage"]["urgency"],
                "image_path": report_image_path,
                "heatmap_path": report_heatmap_path,
            }
        )

        patient_probs.setdefault(patient_id, []).append(probs_dict["any"])

        status = "[+] POS" if probs_dict["any"] >= threshold else "[-] NEG"
        print(
            f" [{i+1}/{len(adjacency_df)}] {image_id} → {status} "
            f"cal_any={probs_dict['any']:.4f}"
        )

    if not slice_rows:
        print("\nERROR: No slices were processed successfully.")
        return

    slice_df = pd.DataFrame(slice_rows)
    slice_csv_path = OUTPUT_DIR / "slice_predictions.csv"
    slice_df.to_csv(slice_csv_path, index=False)

    report_summary_df = pd.DataFrame(report_summary_rows)
    report_summary_csv_path = OUTPUT_DIR / "report_summary.csv"
    report_summary_df.to_csv(report_summary_csv_path, index=False)

    patient_rows = []
    for pid, vals in patient_probs.items():
        arr = np.asarray(vals, dtype=np.float32)
        agg_prob = patient_aggregate(arr, PATIENT_AGG_METHOD, PATIENT_TOPK)
        patient_rows.append(
            {
                "patient_id": pid,
                "n_slices": int(len(arr)),
                "agg_method": PATIENT_AGG_METHOD,
                "agg_any_probability": round(float(agg_prob), 6),
                "pred_any": int(agg_prob >= threshold),
            }
        )

    patient_df = pd.DataFrame(patient_rows)
    patient_csv_path = OUTPUT_DIR / "patient_predictions.csv"
    patient_df.to_csv(patient_csv_path, index=False)

    for cam_obj in gradcam_objects:
        cam_obj.remove()

    n_pos = int((slice_df["pred_any"] == 1).sum())
    n_total = len(slice_df)
    n_urgent = int((slice_df["urgency"] == "URGENT").sum())

    print(f"\n{'═' * 72}")
    print(" INFERENCE COMPLETE")
    print(f"{'═' * 72}")
    print(f" Slices processed : {n_total}")
    print(f" Positive slices : {n_pos}")
    print(f" Urgent escalations : {n_urgent}")
    print(f" Patients processed : {len(patient_df)}")
    print("\n Outputs:")
    print(f" JSON reports : {reports_dir}")
    print(f" Report images : {reports_dir}")
    print(f" Report summary : {report_summary_csv_path}")
    print(f" Slice CSV : {slice_csv_path}")
    print(f" Patient CSV : {patient_csv_path}")
    print(f"{'═' * 72}")


if __name__ == "__main__":
    main()