Niko.Koutsoubis committed
Commit a091733
Parent(s): 88d9d81

Add embedding extraction pipeline for federated learning

Files changed:
- EXTRACTION_README.md +214 -0
- extract-embeddings.py +986 -0
EXTRACTION_README.md
ADDED
@@ -0,0 +1,214 @@
# Sybil Embedding Extraction Pipeline

This script extracts 512-dimensional embeddings from chest CT DICOM scans using the Sybil lung cancer risk prediction model. It's designed for **federated learning** deployments where sites need to generate embeddings locally without sharing raw medical images.

## Features

- ✅ **Automatic Model Download**: Downloads Sybil model from HuggingFace automatically
- ✅ **Multi-GPU Support**: Process scans in parallel across multiple GPUs
- ✅ **Smart Filtering**: Automatically filters out localizer/scout scans
- ✅ **PID-Based Extraction**: Extract embeddings for specific patient cohorts
- ✅ **Checkpoint System**: Save progress every N scans to prevent data loss
- ✅ **Timepoint Detection**: Automatically detects T0, T1, T2... from scan dates
- ✅ **Directory Caching**: Cache directory scans for 100x faster reruns

## Quick Start

### Installation

```bash
# Install required packages
pip install huggingface_hub torch numpy pandas pydicom
```

### Basic Usage

```bash
# Extract embeddings from all scans
python extract-embeddings.py \
    --root-dir /path/to/NLST/data \
    --output-dir embeddings_output
```

### Extract Specific Patient Cohort

```bash
# Extract only patients listed in a CSV file
python extract-embeddings.py \
    --root-dir /path/to/NLST/data \
    --pid-csv subsets/train_pids.csv \
    --output-dir embeddings_train
```

## Command Line Arguments

### Required
- `--root-dir`: Root directory containing DICOM files (e.g., `/data/NLST`)

### Optional - Data Selection
- `--pid-csv`: CSV file with a "pid" column to filter specific patients (see the example below)
- `--max-subjects`: Limit to N subjects (useful for testing)
- `--output-dir`: Output directory (default: `embeddings_output`)

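The PID CSV only needs a `pid` column whose values match the patient folder names under `NLST/`; the script compares them as strings. A minimal sketch of building one with pandas (the PIDs shown are placeholders, not real NLST subjects):

```python
import pandas as pd

# Placeholder PIDs for illustration only; use the patient IDs of your own cohort.
cohort = pd.DataFrame({"pid": ["100012", "100045", "100107"]})
cohort.to_csv("train_pids.csv", index=False)

# The extraction script reads the same column back and compares PIDs as strings
# against the folder names under NLST/.
pids = set(str(p) for p in pd.read_csv("train_pids.csv")["pid"].unique())
print(f"{len(pids)} unique PIDs")
```
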
### Optional - Performance Tuning
- `--num-gpus`: Number of GPUs to use (default: 1)
- `--num-parallel`: Process N scans simultaneously (default: 1, recommended: 1-4)
- `--num-workers`: Parallel workers for directory scanning (default: 4, recommended: 4-12)
- `--checkpoint-interval`: Save a checkpoint every N scans (default: 1000)

## Expected Directory Structure

Your DICOM data should follow this structure:

```
/path/to/NLST/
├── NLST/
│   ├── <PID_1>/
│   │   ├── MM-DD-YYYY-NLST-LSS-<scan_id>/
│   │   │   ├── <series_id>/
│   │   │   │   ├── *.dcm
│   │   │   │   └── ...
│   │   │   └── ...
│   │   └── ...
│   ├── <PID_2>/
│   └── ...
```

## Output Format

### Embeddings File: `all_embeddings.parquet`

Parquet file with columns:
- `case_number`: Patient ID (PID)
- `subject_id`: Same as case_number
- `scan_id`: Unique scan identifier
- `timepoint`: T0, T1, T2... (year-based, e.g., 1999→T0, 2000→T1)
- `dicom_directory`: Full path to scan directory
- `num_dicom_files`: Number of DICOM slices
- `embedding_index`: Index in embedding array
- `embedding`: 512-dimensional embedding array

### Metadata File: `dataset_metadata.json`

Complete metadata including:
- Dataset info (total scans, embedding dimensions)
- Model info (Sybil ensemble, extraction layer)
- Per-scan metadata (paths, statistics)
- Failed scans with error messages (see the loading sketch below)

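The JSON file mirrors what the script records at the end of a run, so it can be used to audit a run without touching the parquet. A short sketch of loading it and listing any failed scans (paths assume the default `embeddings_output` directory):

```python
import json

with open("embeddings_output/dataset_metadata.json") as f:
    meta = json.load(f)

# High-level run summary
print("Total scans:", meta["dataset_info"]["total_scans"])
print("Embedding dim:", meta["dataset_info"]["embedding_dim"])

# Each entry in failed_scans records the scan directory and the error message
for fail in meta["failed_scans"]:
    print(f"{fail['scan_dir']}: {fail['error']}")
```
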
## Performance Tips

### For Large Datasets (>10K scans)

```bash
# Use cached directory list and multi-GPU processing
python extract-embeddings.py \
    --root-dir /data/NLST \
    --num-gpus 4 \
    --num-parallel 4 \
    --num-workers 12 \
    --checkpoint-interval 500
```

**Memory Requirements**: ~10GB VRAM per parallel scan
- `--num-parallel 1`: Safe for 16GB GPUs
- `--num-parallel 2`: Safe for 24GB GPUs
- `--num-parallel 4`: Requires 40GB+ GPUs

### For Subset Extraction (Train/Test Split)

```bash
# Extract training set
python extract-embeddings.py \
    --root-dir /data/NLST \
    --pid-csv train_pids.csv \
    --output-dir embeddings_train \
    --num-workers 12

# Extract test set
python extract-embeddings.py \
    --root-dir /data/NLST \
    --pid-csv test_pids.csv \
    --output-dir embeddings_test \
    --num-workers 12
```

**Speed**: With PID filtering, scanning 100K subjects for 100 PIDs takes ~5 seconds (100x speedup)

## Loading Embeddings for Training

```python
import pandas as pd
import numpy as np

# Load embeddings
df = pd.read_parquet('embeddings_output/all_embeddings.parquet')

# Extract embedding array
embeddings = np.stack(df['embedding'].values)  # Shape: (num_scans, 512)

# Access metadata
pids = df['case_number'].values
timepoints = df['timepoint'].values
```

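If the downstream model is trained in PyTorch, the parquet can be wrapped in a dataset directly. A minimal sketch, assuming labels come from a separate table joined on `case_number`; the `labels_df` below is a placeholder, not something this pipeline produces:

```python
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset

df = pd.read_parquet("embeddings_output/all_embeddings.parquet")

# Placeholder label table keyed by patient ID; replace with your own outcome data.
labels_df = pd.DataFrame({"case_number": df["case_number"].unique(), "label": 0})
merged = df.merge(labels_df, on="case_number", how="inner")

# Stack per-scan embeddings into a (N, 512) tensor and pair with labels
X = torch.tensor(np.stack(merged["embedding"].values), dtype=torch.float32)
y = torch.tensor(merged["label"].values, dtype=torch.long)

loader = DataLoader(TensorDataset(X, y), batch_size=64, shuffle=True)
```
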
## Troubleshooting

### Out of Memory (OOM) Errors
- Reduce `--num-parallel` to 1 or 2
- Use fewer GPUs with `--num-gpus 1`

### Slow Directory Scanning
- Increase `--num-workers` (try 8-12 for fast storage)
- Use `--pid-csv` to filter early (100x speedup)
- Reruns use the cached directory list automatically

### Missing Timepoints
- Timepoints are extracted from the year in the scan path (1999→T0, 2000→T1)
- If `timepoint` is None, no year pattern was found in the path
- You can manually map scans to timepoints using the `dicom_directory` column, as in the sketch below

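The fallback below applies the same MM-DD-YYYY rule the script uses, as a sketch; adjust `base_year` if your cohort's first screening year differs:

```python
import re

import pandas as pd

df = pd.read_parquet("embeddings_output/all_embeddings.parquet")

def timepoint_from_path(path, base_year=1999):
    # Look for a folder name starting with MM-DD-YYYY and map its year to T0, T1, ...
    for part in str(path).split("/"):
        match = re.match(r"^\d{2}-\d{2}-(19\d{2}|20\d{2})", part)
        if match:
            return f"T{int(match.group(1)) - base_year}"
    return None

# Fill only the rows where automatic detection returned None
df["timepoint"] = df["timepoint"].fillna(df["dicom_directory"].map(timepoint_from_path))
```
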
### Failed Scans
- Check `dataset_metadata.json` for the `failed_scans` section
- Common causes: corrupted DICOM files, insufficient slices, invalid metadata

## Federated Learning Integration

This script is designed for **privacy-preserving federated learning**:

1. **Each site runs extraction locally** on their DICOM data
2. **Embeddings are saved** (not raw DICOM images)
3. **Sites share embeddings** with the federated learning system
4. **Central server trains model** on embeddings without accessing raw data

### Workflow for Sites

```bash
# 1. Download extraction script
wget https://huggingface.co/Lab-Rasool/sybil/resolve/main/extract-embeddings.py

# 2. Extract embeddings for train/test splits
python extract-embeddings.py --root-dir /local/NLST --pid-csv train_pids.csv --output-dir train_embeddings
python extract-embeddings.py --root-dir /local/NLST --pid-csv test_pids.csv --output-dir test_embeddings

# 3. Share embeddings with federated learning system
# (embeddings are much smaller and preserve privacy better than raw DICOM)
```

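Before sharing, it can be worth a quick local sanity check of each output directory. A small sketch of the kind of check a site might run (the expected 512 dimension matches the extraction layer described above):

```python
import numpy as np
import pandas as pd

for split in ("train_embeddings", "test_embeddings"):
    df = pd.read_parquet(f"{split}/all_embeddings.parquet")
    emb = np.stack(df["embedding"].values)
    assert emb.shape[1] == 512, f"{split}: unexpected embedding dim {emb.shape[1]}"
    print(f"{split}: {len(df)} scans, {df['case_number'].nunique()} patients, dim {emb.shape[1]}")
```
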
## Citation

If you use this extraction pipeline, please cite the Sybil model:

```bibtex
@article{sybil2023,
  title={A Deep Learning Model to Predict Lung Cancer Risk from Chest CT Scans},
  author={...},
  journal={...},
  year={2023}
}
```

## Support

For issues or questions:
- Model issues: https://huggingface.co/Lab-Rasool/sybil
- Federated learning: Contact your FL system administrator

extract-embeddings.py
ADDED
@@ -0,0 +1,986 @@
from huggingface_hub import snapshot_download
import sys
import os
import torch
import numpy as np
import pandas as pd
import json
import re
import pydicom
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from pathlib import Path
import threading
import multiprocessing as mp

# Download and setup model
model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
sys.path.append(model_path)

from modeling_sybil_hf import SybilHFWrapper
from configuration_sybil import SybilConfig

def load_model(device_id=0):
    """
    Load and initialize the Sybil model once.

    Args:
        device_id: GPU device ID to load model on

    Returns:
        Tuple of (model, device) with the initialized SybilHFWrapper on the requested GPU
    """
    print(f"Loading Sybil model on GPU {device_id}...")
    config = SybilConfig()
    model = SybilHFWrapper(config)

    # Move model to specific GPU
    device = torch.device(f'cuda:{device_id}')

    # CRITICAL: Set the model's internal device attribute
    # This ensures preprocessing moves data to the correct GPU
    model.device = device

    # Move all ensemble models to the correct GPU
    for m in model.models:
        m.to(device)
        m.eval()

    print(f"Model loaded successfully on GPU {device_id}!")
    print(f" Model internal device: {model.device}")
    return model, device

def is_localizer_scan(dicom_folder):
    """
    Check if a DICOM folder contains a localizer/scout scan.
    Based on preprocessing.py logic.

    Returns:
        Tuple of (is_localizer, reason)
    """
    folder_path = Path(dicom_folder)
    folder_name = folder_path.name.lower()
    localizer_keywords = ['localizer', 'scout', 'topogram', 'surview', 'scanogram']

    # Check folder name
    if any(keyword in folder_name for keyword in localizer_keywords):
        return True, f"Folder name contains localizer keyword: {folder_name}"

    try:
        dcm_files = list(folder_path.glob("*.dcm"))
        if not dcm_files:
            return False, "No DICOM files found"

        # Check first few DICOM files for localizer metadata
        sample_files = dcm_files[:min(3, len(dcm_files))]
        for dcm_file in sample_files:
            try:
                dcm = pydicom.dcmread(str(dcm_file), stop_before_pixels=True)

                # Check ImageType field
                if hasattr(dcm, 'ImageType'):
                    image_type_str = ' '.join(str(val).lower() for val in dcm.ImageType)
                    if any(keyword in image_type_str for keyword in localizer_keywords):
                        return True, f"ImageType indicates localizer: {dcm.ImageType}"

                # Check SeriesDescription field
                if hasattr(dcm, 'SeriesDescription'):
                    if any(keyword in dcm.SeriesDescription.lower() for keyword in localizer_keywords):
                        return True, f"SeriesDescription indicates localizer: {dcm.SeriesDescription}"
            except Exception as e:
                continue
    except Exception as e:
        pass

    return False, "Not a localizer scan"

def extract_timepoint_from_path(scan_dir):
    """
    Extract timepoint from scan directory path based on year.
    1999 -> T0, 2000 -> T1, 2001 -> T2, etc.

    Looks for year patterns in folder names in date format MM-DD-YYYY.

    Args:
        scan_dir: Directory path string

    Returns:
        Timepoint string (e.g., 'T0', 'T1', 'T2') or None if not found
    """
    # Split path into components
    path_parts = scan_dir.split('/')

    # Look for date patterns like "01-02-2000-NLST-LSS"
    # Pattern: Date format MM-DD-YYYY at the start of a folder name
    date_pattern = r'^\d{2}-\d{2}-(19\d{2}|20\d{2})'

    base_year = 1999

    for part in path_parts:
        # Check for date pattern (e.g., "01-02-2000-NLST-LSS-50335")
        match = re.match(date_pattern, part)
        if match:
            year = int(match.group(1))
            if 1999 <= year <= 2010:  # Reasonable range for NLST
                timepoint_num = year - base_year
                print(f" DEBUG: Found year {year} in '{part}' -> T{timepoint_num}")
                return f'T{timepoint_num}'

    return None

def extract_embedding_single_model(model_idx, ensemble_model, pixel_values, device):
    """
    Extract embedding from a single ensemble model.

    Args:
        model_idx: Index of the model in the ensemble
        ensemble_model: Single model from the ensemble
        pixel_values: Preprocessed pixel values tensor (already on correct device)
        device: Device to run on (e.g., cuda:0, cuda:1)

    Returns:
        numpy array of embeddings from this model
    """
    embeddings_buffer = []

    def create_hook(buffer):
        def hook(module, input, output):
            # Capture the output of ReLU layer (before dropout)
            buffer.append(output.detach().cpu())
        return hook

    # Register hook on the ReLU layer (this is AFTER pooling, BEFORE dropout/classification)
    hook_handle = ensemble_model.relu.register_forward_hook(create_hook(embeddings_buffer))

    # Run forward pass on THIS model only with keyword argument
    with torch.no_grad():
        _ = ensemble_model(pixel_values=pixel_values)

    # Remove hook
    hook_handle.remove()

    # Get the embeddings (should be shape [1, 512])
    if embeddings_buffer:
        embedding = embeddings_buffer[0].numpy().squeeze()
        print(f"Model {model_idx + 1}: Embedding shape = {embedding.shape}")
        return embedding
    return None

def extract_embeddings(model, dicom_paths, device, use_parallel=True):
    """
    Extract embeddings from the layer after ReLU, before Dropout.
    Processes ensemble models in parallel for speed.

    Args:
        model: Pre-loaded SybilHFWrapper model
        dicom_paths: List of DICOM file paths
        device: Device to run on (e.g., cuda:0, cuda:1)
        use_parallel: If True, process ensemble models in parallel

    Returns:
        numpy array of shape (512,) - averaged embeddings across ensemble
    """
    # Preprocess ONCE (not 5 times!)
    # The model's preprocessing handles moving data to the correct device
    with torch.no_grad():
        # Get the preprocessed input by calling the wrapper's preprocess_dicom method
        # This returns the tensor that would be fed to each ensemble model
        pixel_values = model.preprocess_dicom(dicom_paths)

    if use_parallel:
        # Process all ensemble models in parallel using ThreadPoolExecutor
        all_embeddings = []

        with ThreadPoolExecutor(max_workers=len(model.models)) as executor:
            # Submit all models for parallel processing with the SAME preprocessed input
            futures = [
                executor.submit(extract_embedding_single_model, model_idx, ensemble_model, pixel_values, device)
                for model_idx, ensemble_model in enumerate(model.models)
            ]

            # Collect results as they complete
            for future in futures:
                embedding = future.result()
                if embedding is not None:
                    all_embeddings.append(embedding)
    else:
        # Sequential processing (original implementation)
        all_embeddings = []
        for model_idx, ensemble_model in enumerate(model.models):
            embedding = extract_embedding_single_model(model_idx, ensemble_model, pixel_values, device)
            if embedding is not None:
                all_embeddings.append(embedding)

    # Average embeddings across ensemble
    averaged_embedding = np.mean(all_embeddings, axis=0)
    return averaged_embedding

def check_directory_for_dicoms(dirpath):
    """
    Check a single directory for valid DICOM files.
    Returns (dirpath, num_files, subject_id, filter_reason) or None if invalid.
    """
    try:
        # Quick check: does this directory have .dcm files?
        dcm_files = [f for f in os.listdir(dirpath)
                     if f.endswith('.dcm') and os.path.isfile(os.path.join(dirpath, f))]

        if not dcm_files:
            return None

        num_files = len(dcm_files)

        # Filter out scans with 1-2 DICOM files (likely localizers)
        if num_files <= 2:
            return (dirpath, num_files, None, 'too_few_slices')

        # Check if it's a localizer scan
        is_loc, _ = is_localizer_scan(dirpath)
        if is_loc:
            return (dirpath, num_files, None, 'localizer')

        # Extract subject ID (PID) from path
        # Path structure: /NLST/<PID>/<date-info>/<scan-info>
        # Example: /NLST/106639/01-02-1999-NLST-LSS-45699/1.000000-0OPLGEHSQXAnullna...
        path_parts = dirpath.rstrip('/').split('/')

        # Find the PID: it's the part after 'NLST' directory
        try:
            nlst_idx = path_parts.index('NLST')
            subject_id = path_parts[nlst_idx + 1]  # PID is right after 'NLST'
        except (ValueError, IndexError):
            # Fallback to old logic if path structure is different
            subject_id = path_parts[-3] if len(path_parts) >= 3 else path_parts[-1]

        return (dirpath, num_files, subject_id, 'valid')

    except Exception as e:
        return None

def save_directory_cache(dicom_dirs, cache_file):
    """
    Save the list of DICOM directories to a cache file.

    Args:
        dicom_dirs: List of directory paths
        cache_file: Path to cache file
    """
    print(f"\n💾 Saving directory cache to {cache_file}...")
    cache_data = {
        "timestamp": datetime.now().isoformat(),
        "num_directories": len(dicom_dirs),
        "directories": dicom_dirs
    }
    with open(cache_file, 'w') as f:
        json.dump(cache_data, f, indent=2)
    print(f"✓ Cache saved with {len(dicom_dirs)} directories\n")

def load_directory_cache(cache_file):
    """
    Load the list of DICOM directories from a cache file.

    Args:
        cache_file: Path to cache file

    Returns:
        List of directory paths, or None if cache doesn't exist or is invalid
    """
    if not os.path.exists(cache_file):
        return None

    try:
        with open(cache_file, 'r') as f:
            cache_data = json.load(f)

        dicom_dirs = cache_data.get("directories", [])
        timestamp = cache_data.get("timestamp", "unknown")

        print(f"\n✓ Loaded directory cache from {cache_file}")
        print(f" Cache created: {timestamp}")
        print(f" Directories: {len(dicom_dirs)}\n")

        return dicom_dirs
    except Exception as e:
        print(f"⚠️ Failed to load cache: {e}")
        return None

def find_dicom_directories(root_dir, max_subjects=None, num_workers=12, cache_file=None, filter_pids=None):
    """
    Walk through directory tree and find all directories containing DICOM files.
    Uses parallel processing for much faster scanning of large directory trees.
    Only returns leaf directories (directories with .dcm files, not their parents).
    Filters out localizer scans with 1-2 DICOM files.

    Args:
        root_dir: Root directory to search
        max_subjects: Optional maximum number of unique subjects to process (None = all)
        num_workers: Number of parallel workers for directory scanning (default: 12)
        cache_file: Optional path to cache file for saving/loading directory list
        filter_pids: Optional set of PIDs to filter (only include these subjects)

    Returns:
        List of directory paths containing .dcm files
    """
    # Try to load from cache first
    if cache_file:
        cached_dirs = load_directory_cache(cache_file)
        if cached_dirs is not None:
            print("✓ Using cached directory list (skipping scan)")

            # Apply PID filter if specified
            if filter_pids:
                print(f" Filtering to {len(filter_pids)} PIDs from CSV...")
                filtered_dirs = []
                for d in cached_dirs:
                    # Extract PID from path: /NLST/<PID>/<date>/<scan>
                    path_parts = d.rstrip('/').split('/')
                    try:
                        nlst_idx = path_parts.index('NLST')
                        subject_id = path_parts[nlst_idx + 1]
                    except (ValueError, IndexError):
                        subject_id = path_parts[-3] if len(path_parts) >= 3 else path_parts[-1]

                    if subject_id in filter_pids:
                        filtered_dirs.append(d)
                print(f" ✓ Found {len(filtered_dirs)} scans matching PIDs")
                return filtered_dirs

            # Still apply max_subjects limit if specified
            if max_subjects:
                subjects_seen = set()
                filtered_dirs = []
                for d in cached_dirs:
                    # Extract PID from path: /NLST/<PID>/<date>/<scan>
                    path_parts = d.rstrip('/').split('/')
                    try:
                        nlst_idx = path_parts.index('NLST')
                        subject_id = path_parts[nlst_idx + 1]
                    except (ValueError, IndexError):
                        subject_id = path_parts[-3] if len(path_parts) >= 3 else path_parts[-1]

                    # Check if we should include this scan
                    # Add scan if: (1) already collecting this subject, OR (2) under subject limit
                    if subject_id in subjects_seen:
                        # Already collecting this subject - add this scan
                        filtered_dirs.append(d)
                    elif len(subjects_seen) < max_subjects:
                        # New subject and under limit - start collecting this subject
                        subjects_seen.add(subject_id)
                        filtered_dirs.append(d)

                    # Stop once we have enough subjects
                    if len(subjects_seen) >= max_subjects:
                        # Count remaining scans from these subjects
                        remaining_count = 0
                        for remaining_d in cached_dirs[cached_dirs.index(d)+1:]:
                            remaining_parts = remaining_d.rstrip('/').split('/')
                            try:
                                remaining_nlst_idx = remaining_parts.index('NLST')
                                remaining_subject_id = remaining_parts[remaining_nlst_idx + 1]
                            except (ValueError, IndexError):
                                remaining_subject_id = remaining_parts[-3] if len(remaining_parts) >= 3 else remaining_parts[-1]
                            if remaining_subject_id in subjects_seen:
                                filtered_dirs.append(remaining_d)
                        break

                print(f" ✓ Limited to {len(subjects_seen)} subjects ({len(filtered_dirs)} total scans)")
                return filtered_dirs
            return cached_dirs

    print(f"Starting parallel directory scan with {num_workers} workers...")
    if filter_pids:
        print(f"⚡ FAST MODE: Only scanning {len(filter_pids)} PIDs (skipping others)")
    else:
        print("Scanning ALL subjects (this may take a while)")

    # Phase 1: Fast parallel scan to find all directories with DICOM files
    # BUT: Skip subject directories not in filter_pids for MASSIVE speedup
    print("\nPhase 1: Scanning filesystem for DICOM directories...")
    start_time = datetime.now()

    # Collect all directories first (fast) - WITH EARLY FILTERING
    all_dirs = []
    for dirpath, dirnames, filenames in os.walk(root_dir):
        # EARLY FILTER: If we have filter_pids, only descend into matching PID directories
        if filter_pids:
            path_parts = dirpath.rstrip('/').split('/')
            try:
                nlst_idx = path_parts.index('NLST')
                # If this is a subject directory (one level below NLST)
                if len(path_parts) == nlst_idx + 2:
                    subject_id = path_parts[nlst_idx + 1]
                    # Skip this subject if not in filter list
                    if subject_id not in filter_pids:
                        dirnames.clear()  # Don't descend into this subject's subdirs
                        continue
            except (ValueError, IndexError):
                pass

        # Quick check: if directory has .dcm files, add to list
        if any(f.endswith('.dcm') for f in filenames):
            all_dirs.append(dirpath)

    print(f"Found {len(all_dirs)} potential DICOM directories in {(datetime.now() - start_time).total_seconds():.1f}s")

    # Phase 2: Parallel validation and filtering
    print(f"\nPhase 2: Validating directories in parallel ({num_workers} workers)...")

    from concurrent.futures import ProcessPoolExecutor, as_completed

    dicom_dirs = []
    subjects_found = set()
    filtered_stats = {'localizers': 0, 'too_few_slices': 0}

    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Submit all directories for checking
        future_to_dir = {executor.submit(check_directory_for_dicoms, d): d for d in all_dirs}

        # Process results as they complete
        for i, future in enumerate(as_completed(future_to_dir), 1):
            # Print progress every 1000 dirs (more frequent for visibility)
            if i % 1000 == 0:
                elapsed = (datetime.now() - start_time).total_seconds()
                rate = i / elapsed if elapsed > 0 else 0
                remaining = (len(all_dirs) - i) / rate if rate > 0 else 0
                print(f" [{i}/{len(all_dirs)}] Found: {len(dicom_dirs)} scans from {len(subjects_found)} PIDs | "
                      f"Filtered: {filtered_stats['localizers'] + filtered_stats['too_few_slices']} | "
                      f"ETA: {remaining/60:.1f} min")

            try:
                result = future.result()
                if result is None:
                    continue

                dirpath, num_files, subject_id, status = result

                if status == 'too_few_slices':
                    filtered_stats['too_few_slices'] += 1
                elif status == 'localizer':
                    filtered_stats['localizers'] += 1
                elif status == 'valid':
                    # Check PID filter
                    if filter_pids is not None and subject_id not in filter_pids:
                        continue

                    # Check subject limit
                    if max_subjects is not None and subject_id not in subjects_found and len(subjects_found) >= max_subjects:
                        continue

                    subjects_found.add(subject_id)
                    dicom_dirs.append(dirpath)

                    # Print when we find a new PID match (helpful for filtered runs)
                    if filter_pids and len(dicom_dirs) % 100 == 1:
                        print(f" ✓ Found {len(dicom_dirs)} scans so far ({len(subjects_found)} unique PIDs)")

                    # Stop if we've hit subject limit
                    if max_subjects is not None and len(subjects_found) >= max_subjects:
                        print(f"\n✓ Reached limit of {max_subjects} subjects. Stopping search.")
                        # Cancel remaining futures
                        for f in future_to_dir:
                            f.cancel()
                        break

            except Exception as e:
                continue

    scan_time = (datetime.now() - start_time).total_seconds()

    print(f"\n{'='*80}")
    print(f"Directory Scan Complete in {scan_time:.1f}s ({scan_time/60:.1f} minutes)")
    print(f"{'='*80}")
    print(f"Filtering Summary:")
    print(f" ✅ Valid scans found: {len(dicom_dirs)}")
    print(f" 🚫 Localizers filtered: {filtered_stats['localizers']}")
    print(f" ⏭️ Too few slices (≤2) filtered: {filtered_stats['too_few_slices']}")
    print(f" 📊 Unique subjects: {len(subjects_found)}")
    print(f" ⚡ Speed: {len(all_dirs)/scan_time:.0f} dirs/second")
    print(f"{'='*80}\n")

    # Save to cache if specified
    if cache_file:
        save_directory_cache(dicom_dirs, cache_file)

    return dicom_dirs

def prepare_scan_metadata(scan_dir):
    """
    Prepare metadata for a scan without processing.

    Args:
        scan_dir: Directory containing DICOM files for one scan

    Returns:
        tuple: (dicom_file_paths, num_files, subject_id, scan_id)
    """
    # Count DICOM files (ensure they are actual files, not directories)
    dicom_files = [f for f in os.listdir(scan_dir)
                   if f.endswith('.dcm') and os.path.isfile(os.path.join(scan_dir, f))]
    num_dicom_files = len(dicom_files)

    if num_dicom_files == 0:
        raise ValueError("No valid DICOM files found")

    # Create list of full paths to DICOM files
    dicom_file_paths = [os.path.join(scan_dir, f) for f in dicom_files]

    # Parse directory path to extract identifiers
    # Path structure: /NLST/<PID>/<date-info>/<scan-info>
    path_parts = scan_dir.rstrip('/').split('/')
    scan_id = path_parts[-1] if path_parts[-1] else path_parts[-2]

    # Extract PID from path
    try:
        nlst_idx = path_parts.index('NLST')
        subject_id = path_parts[nlst_idx + 1]  # PID is right after 'NLST'
    except (ValueError, IndexError):
        # Fallback to old logic
        subject_id = path_parts[-3] if len(path_parts) >= 3 else path_parts[-1]

    return dicom_file_paths, num_dicom_files, subject_id, scan_id

def save_checkpoint(all_embeddings, all_metadata, failed, output_dir, checkpoint_num):
    """
    Save a checkpoint of embeddings and metadata.

    Args:
        all_embeddings: List of embedding arrays
        all_metadata: List of metadata dictionaries
        failed: List of failed scans
        output_dir: Output directory
        checkpoint_num: Checkpoint number
    """
    print(f"\n💾 Saving checkpoint {checkpoint_num}...")

    # Convert embeddings to array
    embeddings_array = np.array(all_embeddings)
    embedding_dim = int(embeddings_array.shape[1]) if len(embeddings_array.shape) > 1 else int(embeddings_array.shape[0])

    # Create DataFrame
    df_data = {
        'case_number': [m['case_number'] for m in all_metadata],
        'subject_id': [m['subject_id'] for m in all_metadata],
        'scan_id': [m['scan_id'] for m in all_metadata],
        'timepoint': [m.get('timepoint') for m in all_metadata],
        'dicom_directory': [m['dicom_directory'] for m in all_metadata],
        'num_dicom_files': [m['num_dicom_files'] for m in all_metadata],
        'embedding_index': [m['embedding_index'] for m in all_metadata],
        'embedding': list(embeddings_array)
    }
    df = pd.DataFrame(df_data)

    # Save checkpoint parquet
    checkpoint_path = os.path.join(output_dir, f"checkpoint_{checkpoint_num}_embeddings.parquet")
    df.to_parquet(checkpoint_path, index=False, compression='snappy')
    print(f" ✓ Saved embeddings checkpoint: {checkpoint_path}")

    # Save checkpoint metadata
    checkpoint_metadata = {
        "checkpoint_num": checkpoint_num,
        "timestamp": datetime.now().isoformat(),
        "total_scans": len(all_embeddings),
        # Count stored under its own key so it is not overwritten by the
        # "failed_scans" list below (the dict previously used the same key twice).
        "num_failed_scans": len(failed),
        "embedding_shape": list(embeddings_array.shape),
        "scans": all_metadata,
        "failed_scans": failed
    }
    metadata_path = os.path.join(output_dir, f"checkpoint_{checkpoint_num}_metadata.json")
    with open(metadata_path, 'w') as f:
        json.dump(checkpoint_metadata, f, indent=2)
    print(f" ✓ Saved metadata checkpoint: {metadata_path}")
    print(f"💾 Checkpoint {checkpoint_num} complete!\n")

def process_scan(model, device, scan_dir):
    """
    Process a single scan directory and extract embeddings.

    Args:
        model: Pre-loaded SybilHFWrapper model
        device: Device to run on (e.g., cuda:0, cuda:1)
        scan_dir: Directory containing DICOM files for one scan

    Returns:
        tuple: (embeddings, scan_metadata)
    """
    dicom_file_paths, num_dicom_files, subject_id, scan_id = prepare_scan_metadata(scan_dir)

    print(f"\nProcessing: {scan_dir}")
    print(f"DICOM files: {num_dicom_files}")

    # Extract embeddings
    embeddings = extract_embeddings(model, dicom_file_paths, device)

    print(f"Embedding shape: {embeddings.shape}")

    # Extract timepoint from path (e.g., 1999 -> T0, 2000 -> T1)
    timepoint = extract_timepoint_from_path(scan_dir)
    if timepoint:
        print(f"Timepoint: {timepoint}")
    else:
        print(f"Timepoint: Not detected")

    # Create metadata for this scan
    scan_metadata = {
        "case_number": subject_id,  # Case number (e.g., 205749)
        "subject_id": subject_id,
        "scan_id": scan_id,
        "timepoint": timepoint,  # T0, T1, T2, etc. or None
        "dicom_directory": scan_dir,
        "num_dicom_files": num_dicom_files,
        "embedding_index": None,  # Will be set later
        "statistics": {
            "mean": float(np.mean(embeddings)),
            "std": float(np.std(embeddings)),
            "min": float(np.min(embeddings)),
            "max": float(np.max(embeddings))
        }
    }

    return embeddings, scan_metadata

# Main execution
if __name__ == "__main__":
    import argparse

    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Extract Sybil embeddings from DICOM scans')

    # Input/Output
    parser.add_argument('--root-dir', type=str, required=True,
                        help='Root directory containing DICOM files (e.g., /path/to/NLST)')
    parser.add_argument('--pid-csv', type=str, default=None,
                        help='CSV file with "pid" column to filter subjects (e.g., subsets/hybridModels-train.csv)')
    parser.add_argument('--output-dir', type=str, default='embeddings_output',
                        help='Output directory for embeddings (default: embeddings_output)')
    parser.add_argument('--max-subjects', type=int, default=None,
                        help='Maximum number of subjects to process (for testing)')

    # Performance tuning
    parser.add_argument('--num-gpus', type=int, default=1,
                        help='Number of GPUs to use (default: 1)')
    parser.add_argument('--num-parallel', type=int, default=1,
                        help='Number of parallel scans to process simultaneously (default: 1, recommended: 1-4 depending on GPU memory)')
    parser.add_argument('--num-workers', type=int, default=4,
                        help='Number of parallel workers for directory scanning (default: 4, recommended: 4-12 depending on storage speed)')
    parser.add_argument('--checkpoint-interval', type=int, default=1000,
                        help='Save checkpoint every N scans (default: 1000)')

    args = parser.parse_args()

    # ==========================================
    # CONFIGURATION
    # ==========================================
    root_dir = args.root_dir
    output_dir = args.output_dir
    max_subjects = args.max_subjects
    num_gpus = args.num_gpus
    num_parallel_scans = args.num_parallel
    num_scan_workers = args.num_workers
    checkpoint_interval = args.checkpoint_interval

    # Always use the main cache file from the full run
    main_cache = "embeddings_output_full/directory_cache.json"
    if os.path.exists(main_cache):
        cache_file = main_cache
        print(f"✓ Found main directory cache: {main_cache}")
    else:
        cache_file = os.path.join(output_dir, "directory_cache.json")

    # Verify root directory exists
    if not os.path.exists(root_dir):
        raise ValueError(f"Root directory does not exist: {root_dir}")

    # Load PIDs from CSV if provided
    filter_pids = None
    if args.pid_csv:
        print(f"Loading subject PIDs from: {args.pid_csv}")
        import pandas as pd
        csv_data = pd.read_csv(args.pid_csv)
        filter_pids = set(str(pid) for pid in csv_data['pid'].unique())
        print(f" Found {len(filter_pids)} unique PIDs to extract")
        print(f" Examples: {list(filter_pids)[:5]}")

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Print configuration
    print(f"\n{'='*80}")
    print(f"CONFIGURATION")
    print(f"{'='*80}")
    print(f"Root directory: {root_dir}")
    print(f"Output directory: {output_dir}")
    print(f"Number of GPUs: {num_gpus}")
    print(f"Parallel scans: {num_parallel_scans} (recommended: 1-4 depending on GPU memory)")
    print(f"Directory scan workers: {num_scan_workers} (recommended: 4-12 depending on storage)")
    print(f"Checkpoint interval: {checkpoint_interval} scans")
    if filter_pids:
        print(f"Filtering to: {len(filter_pids)} PIDs from CSV")
    if max_subjects:
        print(f"Max subjects: {max_subjects}")
    print(f"{'='*80}\n")

    # Warning about memory requirements
    if num_parallel_scans > 1:
        estimated_vram = (num_parallel_scans // num_gpus) * 10
        print(f"⚠️ MEMORY WARNING:")
        print(f" Parallel processing requires ~{estimated_vram}GB VRAM per GPU")
        print(f" If you encounter OOM errors, reduce --num-parallel to 1-2")
        print(f" Current: {num_parallel_scans} scans across {num_gpus} GPU(s)\n")

    # Find all directories containing DICOM files (FAST with parallel processing!)
    # Will use cached directory list if available, otherwise scan and save cache
    dicom_dirs = find_dicom_directories(root_dir, max_subjects=max_subjects,
                                        num_workers=num_scan_workers, cache_file=cache_file,
                                        filter_pids=filter_pids)

    if len(dicom_dirs) == 0:
        raise ValueError(f"No directories with DICOM files found in {root_dir}")

    print(f"\n{'='*80}")
    print(f"Found {len(dicom_dirs)} directories containing DICOM files")
    print(f"{'='*80}\n")

    # Detect and load models on multiple GPUs
    print(f"🎮 Detected {num_gpus} GPU(s)")
    print(f"🚀 Will process {num_parallel_scans} scans in parallel ({num_parallel_scans // num_gpus} per GPU)")
    print(f"💾 Checkpoints will be saved every {checkpoint_interval} scans\n")

    # Load models on each GPU
    models_and_devices = []
    for gpu_id in range(num_gpus):
        model, device = load_model(gpu_id)
        models_and_devices.append((model, device, gpu_id))

    # Process each scan directory and collect all embeddings
    all_embeddings = []
    all_metadata = []
    failed = []
    checkpoint_counter = 0

    if num_parallel_scans > 1:
        # Parallel processing of multiple scans across multiple GPUs
        print(f"Processing {num_parallel_scans} scans in parallel across {num_gpus} GPU(s)...")
        print(f"Note: This requires ~{(num_parallel_scans // num_gpus) * 10}GB VRAM per GPU.\n")

        from functools import partial
        from concurrent.futures import as_completed

        # Process scans in batches for checkpoint saving
        batch_size = checkpoint_interval
        num_batches = (len(dicom_dirs) + batch_size - 1) // batch_size

        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(dicom_dirs))
            batch_dirs = dicom_dirs[start_idx:end_idx]

            print(f"\n{'='*80}")
            print(f"Processing batch {batch_idx + 1}/{num_batches} (scans {start_idx + 1} to {end_idx})")
            print(f"{'='*80}\n")

            # Use ThreadPoolExecutor for parallel scan processing
            # IMPORTANT: max_workers limits concurrent execution to prevent OOM
            with ThreadPoolExecutor(max_workers=num_parallel_scans) as executor:
                # Submit scans in controlled batches to avoid memory issues
                # We submit only max_workers scans at once, then submit more as they complete
                future_to_info = {}
                scan_queue = list(enumerate(batch_dirs))
                scans_submitted = 0

                # Submit initial batch (up to max_workers scans)
                while scan_queue and scans_submitted < num_parallel_scans:
                    i, scan_dir = scan_queue.pop(0)
                    # Select GPU in round-robin fashion
                    gpu_idx = i % num_gpus
                    model, device, gpu_id = models_and_devices[gpu_idx]

                    # Create partial function with model and device
                    process_func = partial(process_scan, model, device)
                    future = executor.submit(process_func, scan_dir)
                    future_to_info[future] = (start_idx + i + 1, scan_dir, gpu_id)
                    scans_submitted += 1

                # Process results as they complete and submit new scans
                while future_to_info:
                    # Wait for next completion
                    done_futures = []
                    for future in list(future_to_info.keys()):
                        if future.done():
                            done_futures.append(future)

                    if not done_futures:
                        import time
                        time.sleep(0.1)
                        continue

                    # Process completed futures
                    for future in done_futures:
                        scan_num, scan_dir, gpu_id = future_to_info.pop(future)
                        try:
                            print(f"[{scan_num}/{len(dicom_dirs)}] Processing on GPU {gpu_id}...")
                            embeddings, scan_metadata = future.result()

                            # Set the index for this scan
                            scan_metadata["embedding_index"] = len(all_embeddings)

                            # Collect embeddings and metadata
                            all_embeddings.append(embeddings)
                            all_metadata.append(scan_metadata)

                        except Exception as e:
                            print(f"ERROR processing {scan_dir}: {e}")
                            failed.append({"scan_dir": scan_dir, "error": str(e)})

                        # Submit next scan from queue
                        if scan_queue:
                            i, next_scan_dir = scan_queue.pop(0)
                            gpu_idx = i % num_gpus
                            model, device, gpu_id = models_and_devices[gpu_idx]

                            process_func = partial(process_scan, model, device)
                            new_future = executor.submit(process_func, next_scan_dir)
                            future_to_info[new_future] = (start_idx + i + 1, next_scan_dir, gpu_id)

            # Save checkpoint after each batch
            checkpoint_counter += 1
            save_checkpoint(all_embeddings, all_metadata, failed, output_dir, checkpoint_counter)

            print(f"Progress: {len(all_embeddings)}/{len(dicom_dirs)} scans completed "
                  f"({len(all_embeddings)/len(dicom_dirs)*100:.1f}%)\n")
    else:
        # Sequential processing (original behavior)
        model, device, gpu_id = models_and_devices[0]  # Use first GPU

        for i, scan_dir in enumerate(dicom_dirs, 1):
            try:
                print(f"\n[{i}/{len(dicom_dirs)}] Processing scan...")

                # Process scan and get results
                embeddings, scan_metadata = process_scan(model, device, scan_dir)

                # Set the index for this scan
                scan_metadata["embedding_index"] = len(all_embeddings)

                # Collect embeddings and metadata
                all_embeddings.append(embeddings)
                all_metadata.append(scan_metadata)

                # Save checkpoint every checkpoint_interval scans
                if i % checkpoint_interval == 0:
                    checkpoint_counter += 1
                    save_checkpoint(all_embeddings, all_metadata, failed, output_dir, checkpoint_counter)

            except Exception as e:
                print(f"ERROR processing {scan_dir}: {e}")
                failed.append({"scan_dir": scan_dir, "error": str(e)})

    # Convert embeddings list to numpy array
    # Shape will be (num_scans, embedding_dim)
    embeddings_array = np.array(all_embeddings)
    embedding_dim = int(embeddings_array.shape[1]) if len(embeddings_array.shape) > 1 else int(embeddings_array.shape[0])

    # Create DataFrame with embeddings and metadata for Parquet
    # Store embeddings as a single array column
    df_data = {
        'case_number': [m['case_number'] for m in all_metadata],
        'subject_id': [m['subject_id'] for m in all_metadata],
        'scan_id': [m['scan_id'] for m in all_metadata],
        'timepoint': [m.get('timepoint') for m in all_metadata],  # T0, T1, T2, etc.
        'dicom_directory': [m['dicom_directory'] for m in all_metadata],
        'num_dicom_files': [m['num_dicom_files'] for m in all_metadata],
        'embedding_index': [m['embedding_index'] for m in all_metadata],
        'embedding': list(embeddings_array)  # Store as list of arrays
    }

    # Create DataFrame
    df = pd.DataFrame(df_data)

    # Save final complete file as Parquet
    embeddings_filename = "all_embeddings.parquet"
    embeddings_path = os.path.join(output_dir, embeddings_filename)
    df.to_parquet(embeddings_path, index=False, compression='snappy')
    print(f"\n✅ Saved FINAL embeddings to Parquet: {embeddings_path}")

    # Create comprehensive metadata JSON
    dataset_metadata = {
        "dataset_info": {
            "root_directory": root_dir,
            "total_scans": len(all_embeddings),
            "failed_scans": len(failed),
            "embedding_shape": list(embeddings_array.shape),
            "embedding_dim": embedding_dim,
            "extraction_timestamp": datetime.now().isoformat(),
            "file_format": "parquet"
        },
        "model_info": {
            "model": "Lab-Rasool/sybil",
            "layer": "after_relu_before_dropout",
            "ensemble_averaged": True,
            "num_ensemble_models": 5
        },
        "embeddings_file": embeddings_filename,
        "parquet_schema": {
            "metadata_columns": ["case_number", "subject_id", "scan_id", "timepoint", "dicom_directory", "num_dicom_files", "embedding_index"],
            "embedding_column": "embedding",
            "embedding_shape": f"({embedding_dim},)",
            "total_columns": 8,
            "timepoint_info": "T0=1999, T1=2000, T2=2001, etc. Extracted from year in path. Can be None if not detected."
        },
        "filtering_info": {
            "localizer_detection": "Scans identified as localizers (by folder name or DICOM metadata) are filtered out",
            "min_slices": "Scans with ≤2 DICOM files are filtered out (likely localizers)",
            "accepted_scans": len(all_embeddings)
        },
        "scans": all_metadata,
        "failed_scans": failed
    }

    metadata_filename = "dataset_metadata.json"
    metadata_path = os.path.join(output_dir, metadata_filename)
    with open(metadata_path, 'w') as f:
        json.dump(dataset_metadata, f, indent=2)
    print(f"✅ Saved FINAL metadata: {metadata_path}")

    # Summary
    print(f"\n{'='*80}")
    print(f"PROCESSING COMPLETE")
    print(f"{'='*80}")
    print(f"Successfully processed: {len(all_embeddings)}/{len(dicom_dirs)} scans")
    print(f"Failed: {len(failed)}/{len(dicom_dirs)} scans")
    print(f"\nEmbeddings array shape: {embeddings_array.shape}")
    print(f"Saved embeddings to: {embeddings_path}")
    print(f"Saved metadata to: {metadata_path}")

    # Timepoint summary
    timepoint_counts = {}
    for m in all_metadata:
        tp = m.get('timepoint', 'Unknown')
        timepoint_counts[tp] = timepoint_counts.get(tp, 0) + 1

    if timepoint_counts:
        print(f"\n📅 Timepoint Distribution:")
        for tp in sorted(timepoint_counts.keys(), key=lambda x: (x is None, x)):
            count = timepoint_counts[tp]
            if tp is None:
                print(f" Unknown/Not detected: {count} scans")
            else:
                print(f" {tp}: {count} scans")

    if failed:
        print(f"\nFailed scans: {len(failed)}")
        for fail_info in failed[:5]:  # Show first 5 failures
            print(f" - {fail_info['scan_dir']}")
            print(f"   Error: {fail_info['error']}")
        if len(failed) > 5:
            print(f" ... and {len(failed) - 5} more failures")

    print(f"\n{'='*80}")
    print(f"For downstream training, load embeddings with:")
    print(f" import pandas as pd")
    print(f" import numpy as np")
    print(f" df = pd.read_parquet('{embeddings_path}')")
    print(f" # Total rows: {len(df)}, Total columns: {len(df.columns)}")
    print(f" # Extract embeddings array: embeddings = np.stack(df['embedding'].values)")
    print(f" # Shape: {embeddings_array.shape}")
    print(f" # Access individual: df.loc[0, 'embedding'] -> array of shape ({embedding_dim},)")
    print(f"{'='*80}")