Niko.Koutsoubis committed
Commit a091733 · 1 Parent(s): 88d9d81

Add embedding extraction pipeline for federated learning

Files changed (2)
  1. EXTRACTION_README.md +214 -0
  2. extract-embeddings.py +986 -0
EXTRACTION_README.md ADDED
@@ -0,0 +1,214 @@
# Sybil Embedding Extraction Pipeline

This script extracts 512-dimensional embeddings from chest CT DICOM scans using the Sybil lung cancer risk prediction model. It is designed for **federated learning** deployments where sites need to generate embeddings locally without sharing raw medical images.

## Features

- ✅ **Automatic Model Download**: Downloads the Sybil model from HuggingFace automatically
- ✅ **Multi-GPU Support**: Process scans in parallel across multiple GPUs
- ✅ **Smart Filtering**: Automatically filters out localizer/scout scans
- ✅ **PID-Based Extraction**: Extract embeddings for specific patient cohorts
- ✅ **Checkpoint System**: Save progress every N scans to prevent data loss
- ✅ **Timepoint Detection**: Automatically detects T0, T1, T2... from scan dates
- ✅ **Directory Caching**: Cache directory scans for 100x faster reruns

## Quick Start

### Installation

```bash
# Install required packages
pip install huggingface_hub torch numpy pandas pydicom
```

### Basic Usage

```bash
# Extract embeddings from all scans
python extract-embeddings.py \
    --root-dir /path/to/NLST/data \
    --output-dir embeddings_output
```

### Extract Specific Patient Cohort

```bash
# Extract only patients listed in a CSV file
python extract-embeddings.py \
    --root-dir /path/to/NLST/data \
    --pid-csv subsets/train_pids.csv \
    --output-dir embeddings_train
```

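The `--pid-csv` file only needs a single `pid` column. A minimal sketch for generating train/test PID lists, assuming a hypothetical `labels.csv` that has `pid` and `split` columns:

```python
import pandas as pd

# Hypothetical cohort file with "pid" and "split" columns
labels = pd.read_csv("labels.csv", dtype={"pid": str})

# Write one-column CSVs that extract-embeddings.py accepts via --pid-csv
labels.loc[labels["split"] == "train", ["pid"]].drop_duplicates().to_csv("train_pids.csv", index=False)
labels.loc[labels["split"] == "test", ["pid"]].drop_duplicates().to_csv("test_pids.csv", index=False)
```
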
## Command Line Arguments

### Required
- `--root-dir`: Root directory containing DICOM files (e.g., `/data/NLST`)

### Optional - Data Selection
- `--pid-csv`: CSV file with a "pid" column used to filter to specific patients
- `--max-subjects`: Limit to N subjects (useful for testing)
- `--output-dir`: Output directory (default: `embeddings_output`)

### Optional - Performance Tuning
- `--num-gpus`: Number of GPUs to use (default: 1)
- `--num-parallel`: Process N scans simultaneously (default: 1, recommended: 1-4)
- `--num-workers`: Parallel workers for directory scanning (default: 4, recommended: 4-12)
- `--checkpoint-interval`: Save a checkpoint every N scans (default: 1000)

## Expected Directory Structure

Your DICOM data should follow this structure:
```
/path/to/NLST/
├── NLST/
│   ├── <PID_1>/
│   │   ├── MM-DD-YYYY-NLST-LSS-<scan_id>/
│   │   │   ├── <series_id>/
│   │   │   │   ├── *.dcm
│   │   │   │   └── ...
│   │   │   └── ...
│   │   └── ...
│   ├── <PID_2>/
│   └── ...
```

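Before a long run, it can help to spot-check that the data actually matches this layout. A small sketch using `pathlib` (the root path is a placeholder):

```python
from pathlib import Path

root = Path("/path/to/NLST/NLST")  # placeholder: the folder containing the <PID> directories

for pid_dir in sorted(p for p in root.iterdir() if p.is_dir())[:5]:        # a few patients
    for scan_dir in (d for d in pid_dir.iterdir() if d.is_dir()):          # MM-DD-YYYY-NLST-LSS-<scan_id>
        for series_dir in (d for d in scan_dir.iterdir() if d.is_dir()):   # <series_id>
            n_slices = len(list(series_dir.glob("*.dcm")))
            print(pid_dir.name, scan_dir.name, series_dir.name, n_slices)
```

Series with two or fewer `.dcm` files are treated as likely localizers and skipped by the extractor.
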
## Output Format

### Embeddings File: `all_embeddings.parquet`

Parquet file with columns:
- `case_number`: Patient ID (PID)
- `subject_id`: Same as `case_number`
- `scan_id`: Unique scan identifier
- `timepoint`: T0, T1, T2... (year-based, e.g., 1999→T0, 2000→T1)
- `dicom_directory`: Full path to scan directory
- `num_dicom_files`: Number of DICOM slices
- `embedding_index`: Index in embedding array
- `embedding`: 512-dimensional embedding array

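If you only need the metadata, pandas can read a subset of Parquet columns and skip the 512-dimensional arrays entirely; a minimal sketch:

```python
import pandas as pd

# Read only the lightweight metadata columns; the 512-dim arrays stay on disk
meta = pd.read_parquet(
    "embeddings_output/all_embeddings.parquet",
    columns=["case_number", "scan_id", "timepoint", "num_dicom_files"],
)
print(meta["timepoint"].value_counts(dropna=False))
```
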
### Metadata File: `dataset_metadata.json`

Complete metadata including:
- Dataset info (total scans, embedding dimensions)
- Model info (Sybil ensemble, extraction layer)
- Per-scan metadata (paths, statistics)
- Failed scans with error messages

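A short sketch for inspecting this file after a run, e.g., to review which scans failed and why:

```python
import json

with open("embeddings_output/dataset_metadata.json") as f:
    meta = json.load(f)

print("Total scans:", meta["dataset_info"]["total_scans"])
print("Embedding dim:", meta["dataset_info"]["embedding_dim"])

# Each entry in failed_scans records the scan directory and the error message
for fail in meta["failed_scans"]:
    print(fail["scan_dir"], "->", fail["error"])
```
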
## Performance Tips

### For Large Datasets (>10K scans)

```bash
# Use cached directory list and multi-GPU processing
python extract-embeddings.py \
    --root-dir /data/NLST \
    --num-gpus 4 \
    --num-parallel 4 \
    --num-workers 12 \
    --checkpoint-interval 500
```

**Memory Requirements**: ~10GB VRAM per parallel scan (see the sketch after this list for a quick way to estimate a safe setting)
- `--num-parallel 1`: Safe for 16GB GPUs
- `--num-parallel 2`: Safe for 24GB GPUs
- `--num-parallel 4`: Requires 40GB+ GPUs

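One rough way to pick a setting, assuming the ~10GB-per-scan estimate above holds for your data, is to check the available VRAM first:

```python
import torch

# Rough heuristic based on the ~10GB-per-parallel-scan estimate above
if torch.cuda.is_available():
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    suggested = max(1, int(total_gb // 10))
    print(f"GPU 0: {total_gb:.0f}GB VRAM -> try --num-parallel {min(suggested, 4)}")
```
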
### For Subset Extraction (Train/Test Split)

```bash
# Extract training set
python extract-embeddings.py \
    --root-dir /data/NLST \
    --pid-csv train_pids.csv \
    --output-dir embeddings_train \
    --num-workers 12

# Extract test set
python extract-embeddings.py \
    --root-dir /data/NLST \
    --pid-csv test_pids.csv \
    --output-dir embeddings_test \
    --num-workers 12
```

**Speed**: With PID filtering, scanning 100K subjects for 100 PIDs takes ~5 seconds (a ~100x speedup).

## Loading Embeddings for Training

```python
import pandas as pd
import numpy as np

# Load embeddings
df = pd.read_parquet('embeddings_output/all_embeddings.parquet')

# Extract embedding array
embeddings = np.stack(df['embedding'].values)  # Shape: (num_scans, 512)

# Access metadata
pids = df['case_number'].values
timepoints = df['timepoint'].values
```

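To attach labels for supervised training, the embeddings can be joined on the patient ID. A sketch assuming a hypothetical `labels.csv` with `pid` and `label` columns (`case_number` is stored as a string, so the PIDs are read as strings too):

```python
import pandas as pd
import numpy as np

df = pd.read_parquet("embeddings_train/all_embeddings.parquet")

# Hypothetical label file with "pid" and "label" columns
labels = pd.read_csv("labels.csv", dtype={"pid": str})

# case_number is the PID (stored as a string), so join on it directly
merged = df.merge(labels, left_on="case_number", right_on="pid", how="inner")

X = np.stack(merged["embedding"].values)  # (num_scans, 512)
y = merged["label"].values
print(X.shape, y.shape)
```
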
## Troubleshooting

### Out of Memory (OOM) Errors
- Reduce `--num-parallel` to 1 or 2
- Use fewer GPUs with `--num-gpus 1`

### Slow Directory Scanning
- Increase `--num-workers` (try 8-12 for fast storage)
- Use `--pid-csv` to filter early (100x speedup)
- Reruns automatically use the cached directory list

### Missing Timepoints
- Timepoints are extracted from the year in the scan path (1999→T0, 2000→T1)
- If `timepoint` is None, no year pattern was found in the path
- You can manually map scans to timepoints using the `dicom_directory` column (see the sketch below)

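If some rows still have `timepoint` set to None, the same year-based rule the script uses (an `MM-DD-YYYY` prefix in a folder name, with 1999→T0) can be re-applied to the stored paths; a sketch:

```python
import re
import pandas as pd

df = pd.read_parquet("embeddings_output/all_embeddings.parquet")

def timepoint_from_path(path, base_year=1999):
    # Same rule the extractor uses: an MM-DD-YYYY prefix on a folder name
    for part in str(path).split("/"):
        m = re.match(r"^\d{2}-\d{2}-(19\d{2}|20\d{2})", part)
        if m:
            return f"T{int(m.group(1)) - base_year}"
    return None

# Fill in only the rows where the extractor could not detect a timepoint
df["timepoint"] = df["timepoint"].fillna(df["dicom_directory"].map(timepoint_from_path))
```
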
170
+ ### Failed Scans
171
+ - Check `dataset_metadata.json` for `failed_scans` section
172
+ - Common causes: corrupted DICOM files, insufficient slices, invalid metadata
173
+
## Federated Learning Integration

This script is designed for **privacy-preserving federated learning**:

1. **Each site runs extraction locally** on its own DICOM data
2. **Embeddings are saved** (not raw DICOM images)
3. **Sites share embeddings** with the federated learning system
4. **The central server trains a model** on embeddings without accessing raw data

### Workflow for Sites

```bash
# 1. Download extraction script
wget https://huggingface.co/Lab-Rasool/sybil/resolve/main/extract-embeddings.py

# 2. Extract embeddings for train/test splits
python extract-embeddings.py --root-dir /local/NLST --pid-csv train_pids.csv --output-dir train_embeddings
python extract-embeddings.py --root-dir /local/NLST --pid-csv test_pids.csv --output-dir test_embeddings

# 3. Share embeddings with the federated learning system
# (embeddings are much smaller and preserve privacy better than raw DICOM images)
```

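What exactly gets shared is up to your FL deployment, but one sensible step is to drop columns that expose local filesystem paths before handing the file over; a minimal sketch:

```python
import pandas as pd

df = pd.read_parquet("train_embeddings/all_embeddings.parquet")

# dicom_directory contains absolute local paths; drop it before sharing
shared = df.drop(columns=["dicom_directory"])
shared.to_parquet("train_embeddings/shared_embeddings.parquet", index=False)
```
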
## Citation

If you use this extraction pipeline, please cite the Sybil model:

```bibtex
@article{sybil2023,
  title={A Deep Learning Model to Predict Lung Cancer Risk from Chest CT Scans},
  author={...},
  journal={...},
  year={2023}
}
```

## Support

For issues or questions:
- Model issues: https://huggingface.co/Lab-Rasool/sybil
- Federated learning: Contact your FL system administrator
extract-embeddings.py ADDED
@@ -0,0 +1,986 @@
1
+ from huggingface_hub import snapshot_download
2
+ import sys
3
+ import os
4
+ import torch
5
+ import numpy as np
6
+ import pandas as pd
7
+ import json
8
+ import re
9
+ import pydicom
10
+ from datetime import datetime
11
+ from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
12
+ from pathlib import Path
13
+ import threading
14
+ import multiprocessing as mp
15
+
16
+ # Download and setup model
17
+ model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
18
+ sys.path.append(model_path)
19
+
20
+ from modeling_sybil_hf import SybilHFWrapper
21
+ from configuration_sybil import SybilConfig
22
+
23
+ def load_model(device_id=0):
24
+ """
25
+ Load and initialize the Sybil model once.
26
+
27
+ Args:
28
+ device_id: GPU device ID to load model on
29
+
30
+ Returns:
31
+ Initialized SybilHFWrapper model
32
+ """
33
+ print(f"Loading Sybil model on GPU {device_id}...")
34
+ config = SybilConfig()
35
+ model = SybilHFWrapper(config)
36
+
37
+ # Move model to specific GPU
38
+ device = torch.device(f'cuda:{device_id}')
39
+
40
+ # CRITICAL: Set the model's internal device attribute
41
+ # This ensures preprocessing moves data to the correct GPU
42
+ model.device = device
43
+
44
+ # Move all ensemble models to the correct GPU
45
+ for m in model.models:
46
+ m.to(device)
47
+ m.eval()
48
+
49
+ print(f"Model loaded successfully on GPU {device_id}!")
50
+ print(f" Model internal device: {model.device}")
51
+ return model, device
52
+
53
+ def is_localizer_scan(dicom_folder):
54
+ """
55
+ Check if a DICOM folder contains a localizer/scout scan.
56
+ Based on preprocessing.py logic.
57
+
58
+ Returns:
59
+ Tuple of (is_localizer, reason)
60
+ """
61
+ folder_path = Path(dicom_folder)
62
+ folder_name = folder_path.name.lower()
63
+ localizer_keywords = ['localizer', 'scout', 'topogram', 'surview', 'scanogram']
64
+
65
+ # Check folder name
66
+ if any(keyword in folder_name for keyword in localizer_keywords):
67
+ return True, f"Folder name contains localizer keyword: {folder_name}"
68
+
69
+ try:
70
+ dcm_files = list(folder_path.glob("*.dcm"))
71
+ if not dcm_files:
72
+ return False, "No DICOM files found"
73
+
74
+ # Check first few DICOM files for localizer metadata
75
+ sample_files = dcm_files[:min(3, len(dcm_files))]
76
+ for dcm_file in sample_files:
77
+ try:
78
+ dcm = pydicom.dcmread(str(dcm_file), stop_before_pixels=True)
79
+
80
+ # Check ImageType field
81
+ if hasattr(dcm, 'ImageType'):
82
+ image_type_str = ' '.join(str(val).lower() for val in dcm.ImageType)
83
+ if any(keyword in image_type_str for keyword in localizer_keywords):
84
+ return True, f"ImageType indicates localizer: {dcm.ImageType}"
85
+
86
+ # Check SeriesDescription field
87
+ if hasattr(dcm, 'SeriesDescription'):
88
+ if any(keyword in dcm.SeriesDescription.lower() for keyword in localizer_keywords):
89
+ return True, f"SeriesDescription indicates localizer: {dcm.SeriesDescription}"
90
+ except Exception as e:
91
+ continue
92
+ except Exception as e:
93
+ pass
94
+
95
+ return False, "Not a localizer scan"
96
+
97
+ def extract_timepoint_from_path(scan_dir):
98
+ """
99
+ Extract timepoint from scan directory path based on year.
100
+ 1999 -> T0, 2000 -> T1, 2001 -> T2, etc.
101
+
102
+ Looks for year patterns in folder names in date format MM-DD-YYYY.
103
+
104
+ Args:
105
+ scan_dir: Directory path string
106
+
107
+ Returns:
108
+ Timepoint string (e.g., 'T0', 'T1', 'T2') or None if not found
109
+ """
110
+ # Split path into components
111
+ path_parts = scan_dir.split('/')
112
+
113
+ # Look for date patterns like "01-02-2000-NLST-LSS"
114
+ # Pattern: Date format MM-DD-YYYY at the start of a folder name
115
+ date_pattern = r'^\d{2}-\d{2}-(19\d{2}|20\d{2})'
116
+
117
+ base_year = 1999
118
+
119
+ for part in path_parts:
120
+ # Check for date pattern (e.g., "01-02-2000-NLST-LSS-50335")
121
+ match = re.match(date_pattern, part)
122
+ if match:
123
+ year = int(match.group(1))
124
+ if 1999 <= year <= 2010: # Reasonable range for NLST
125
+ timepoint_num = year - base_year
126
+ print(f" DEBUG: Found year {year} in '{part}' -> T{timepoint_num}")
127
+ return f'T{timepoint_num}'
128
+
129
+ return None
130
+
131
+ def extract_embedding_single_model(model_idx, ensemble_model, pixel_values, device):
132
+ """
133
+ Extract embedding from a single ensemble model.
134
+
135
+ Args:
136
+ model_idx: Index of the model in the ensemble
137
+ ensemble_model: Single model from the ensemble
138
+ pixel_values: Preprocessed pixel values tensor (already on correct device)
139
+ device: Device to run on (e.g., cuda:0, cuda:1)
140
+
141
+ Returns:
142
+ numpy array of embeddings from this model
143
+ """
144
+ embeddings_buffer = []
145
+
146
+ def create_hook(buffer):
147
+ def hook(module, input, output):
148
+ # Capture the output of ReLU layer (before dropout)
149
+ buffer.append(output.detach().cpu())
150
+ return hook
151
+
152
+ # Register hook on the ReLU layer (this is AFTER pooling, BEFORE dropout/classification)
153
+ hook_handle = ensemble_model.relu.register_forward_hook(create_hook(embeddings_buffer))
154
+
155
+ # Run forward pass on THIS model only with keyword argument
156
+ with torch.no_grad():
157
+ _ = ensemble_model(pixel_values=pixel_values)
158
+
159
+ # Remove hook
160
+ hook_handle.remove()
161
+
162
+ # Get the embeddings (should be shape [1, 512])
163
+ if embeddings_buffer:
164
+ embedding = embeddings_buffer[0].numpy().squeeze()
165
+ print(f"Model {model_idx + 1}: Embedding shape = {embedding.shape}")
166
+ return embedding
167
+ return None
168
+
169
+ def extract_embeddings(model, dicom_paths, device, use_parallel=True):
170
+ """
171
+ Extract embeddings from the layer after ReLU, before Dropout.
172
+ Processes ensemble models in parallel for speed.
173
+
174
+ Args:
175
+ model: Pre-loaded SybilHFWrapper model
176
+ dicom_paths: List of DICOM file paths
177
+ device: Device to run on (e.g., cuda:0, cuda:1)
178
+ use_parallel: If True, process ensemble models in parallel
179
+
180
+ Returns:
181
+ numpy array of shape (512,) - averaged embeddings across ensemble
182
+ """
183
+ # Preprocess ONCE (not 5 times!)
184
+ # The model's preprocessing handles moving data to the correct device
185
+ with torch.no_grad():
186
+ # Get the preprocessed input by calling the wrapper's preprocess_dicom method
187
+ # This returns the tensor that would be fed to each ensemble model
188
+ pixel_values = model.preprocess_dicom(dicom_paths)
189
+
190
+ if use_parallel:
191
+ # Process all ensemble models in parallel using ThreadPoolExecutor
192
+ all_embeddings = []
193
+
194
+ with ThreadPoolExecutor(max_workers=len(model.models)) as executor:
195
+ # Submit all models for parallel processing with the SAME preprocessed input
196
+ futures = [
197
+ executor.submit(extract_embedding_single_model, model_idx, ensemble_model, pixel_values, device)
198
+ for model_idx, ensemble_model in enumerate(model.models)
199
+ ]
200
+
201
+ # Collect results as they complete
202
+ for future in futures:
203
+ embedding = future.result()
204
+ if embedding is not None:
205
+ all_embeddings.append(embedding)
206
+ else:
207
+ # Sequential processing (original implementation)
208
+ all_embeddings = []
209
+ for model_idx, ensemble_model in enumerate(model.models):
210
+ embedding = extract_embedding_single_model(model_idx, ensemble_model, pixel_values, device)
211
+ if embedding is not None:
212
+ all_embeddings.append(embedding)
213
+
214
+ # Average embeddings across ensemble
215
+ averaged_embedding = np.mean(all_embeddings, axis=0)
216
+ return averaged_embedding
217
+
218
+ def check_directory_for_dicoms(dirpath):
219
+ """
220
+ Check a single directory for valid DICOM files.
221
+ Returns (dirpath, num_files, subject_id, filter_reason) or None if invalid.
222
+ """
223
+ try:
224
+ # Quick check: does this directory have .dcm files?
225
+ dcm_files = [f for f in os.listdir(dirpath)
226
+ if f.endswith('.dcm') and os.path.isfile(os.path.join(dirpath, f))]
227
+
228
+ if not dcm_files:
229
+ return None
230
+
231
+ num_files = len(dcm_files)
232
+
233
+ # Filter out scans with 1-2 DICOM files (likely localizers)
234
+ if num_files <= 2:
235
+ return (dirpath, num_files, None, 'too_few_slices')
236
+
237
+ # Check if it's a localizer scan
238
+ is_loc, _ = is_localizer_scan(dirpath)
239
+ if is_loc:
240
+ return (dirpath, num_files, None, 'localizer')
241
+
242
+ # Extract subject ID (PID) from path
243
+ # Path structure: /NLST/<PID>/<date-info>/<scan-info>
244
+ # Example: /NLST/106639/01-02-1999-NLST-LSS-45699/1.000000-0OPLGEHSQXAnullna...
245
+ path_parts = dirpath.rstrip('/').split('/')
246
+
247
+ # Find the PID: it's the part after 'NLST' directory
248
+ try:
249
+ nlst_idx = path_parts.index('NLST')
250
+ subject_id = path_parts[nlst_idx + 1] # PID is right after 'NLST'
251
+ except (ValueError, IndexError):
252
+ # Fallback to old logic if path structure is different
253
+ subject_id = path_parts[-3] if len(path_parts) >= 3 else path_parts[-1]
254
+
255
+ return (dirpath, num_files, subject_id, 'valid')
256
+
257
+ except Exception as e:
258
+ return None
259
+
260
+ def save_directory_cache(dicom_dirs, cache_file):
261
+ """
262
+ Save the list of DICOM directories to a cache file.
263
+
264
+ Args:
265
+ dicom_dirs: List of directory paths
266
+ cache_file: Path to cache file
267
+ """
268
+ print(f"\n💾 Saving directory cache to {cache_file}...")
269
+ cache_data = {
270
+ "timestamp": datetime.now().isoformat(),
271
+ "num_directories": len(dicom_dirs),
272
+ "directories": dicom_dirs
273
+ }
274
+ with open(cache_file, 'w') as f:
275
+ json.dump(cache_data, f, indent=2)
276
+ print(f"✓ Cache saved with {len(dicom_dirs)} directories\n")
277
+
278
+ def load_directory_cache(cache_file):
279
+ """
280
+ Load the list of DICOM directories from a cache file.
281
+
282
+ Args:
283
+ cache_file: Path to cache file
284
+
285
+ Returns:
286
+ List of directory paths, or None if cache doesn't exist or is invalid
287
+ """
288
+ if not os.path.exists(cache_file):
289
+ return None
290
+
291
+ try:
292
+ with open(cache_file, 'r') as f:
293
+ cache_data = json.load(f)
294
+
295
+ dicom_dirs = cache_data.get("directories", [])
296
+ timestamp = cache_data.get("timestamp", "unknown")
297
+
298
+ print(f"\n✓ Loaded directory cache from {cache_file}")
299
+ print(f" Cache created: {timestamp}")
300
+ print(f" Directories: {len(dicom_dirs)}\n")
301
+
302
+ return dicom_dirs
303
+ except Exception as e:
304
+ print(f"⚠️ Failed to load cache: {e}")
305
+ return None
306
+
307
+ def find_dicom_directories(root_dir, max_subjects=None, num_workers=12, cache_file=None, filter_pids=None):
308
+ """
309
+ Walk through directory tree and find all directories containing DICOM files.
310
+ Uses parallel processing for much faster scanning of large directory trees.
311
+ Only returns leaf directories (directories with .dcm files, not their parents).
312
+ Filters out localizer scans with 1-2 DICOM files.
313
+
314
+ Args:
315
+ root_dir: Root directory to search
316
+ max_subjects: Optional maximum number of unique subjects to process (None = all)
317
+ num_workers: Number of parallel workers for directory scanning (default: 12)
318
+ cache_file: Optional path to cache file for saving/loading directory list
319
+ filter_pids: Optional set of PIDs to filter (only include these subjects)
320
+
321
+ Returns:
322
+ List of directory paths containing .dcm files
323
+ """
324
+ # Try to load from cache first
325
+ if cache_file:
326
+ cached_dirs = load_directory_cache(cache_file)
327
+ if cached_dirs is not None:
328
+ print("✓ Using cached directory list (skipping scan)")
329
+
330
+ # Apply PID filter if specified
331
+ if filter_pids:
332
+ print(f" Filtering to {len(filter_pids)} PIDs from CSV...")
333
+ filtered_dirs = []
334
+ for d in cached_dirs:
335
+ # Extract PID from path: /NLST/<PID>/<date>/<scan>
336
+ path_parts = d.rstrip('/').split('/')
337
+ try:
338
+ nlst_idx = path_parts.index('NLST')
339
+ subject_id = path_parts[nlst_idx + 1]
340
+ except (ValueError, IndexError):
341
+ subject_id = path_parts[-3] if len(path_parts) >= 3 else path_parts[-1]
342
+
343
+ if subject_id in filter_pids:
344
+ filtered_dirs.append(d)
345
+ print(f" ✓ Found {len(filtered_dirs)} scans matching PIDs")
346
+ return filtered_dirs
347
+
348
+ # Still apply max_subjects limit if specified
349
+ if max_subjects:
350
+ subjects_seen = set()
351
+ filtered_dirs = []
352
+ for d in cached_dirs:
353
+ # Extract PID from path: /NLST/<PID>/<date>/<scan>
354
+ path_parts = d.rstrip('/').split('/')
355
+ try:
356
+ nlst_idx = path_parts.index('NLST')
357
+ subject_id = path_parts[nlst_idx + 1]
358
+ except (ValueError, IndexError):
359
+ subject_id = path_parts[-3] if len(path_parts) >= 3 else path_parts[-1]
360
+
361
+ # Check if we should include this scan
362
+ # Add scan if: (1) already collecting this subject, OR (2) under subject limit
363
+ if subject_id in subjects_seen:
364
+ # Already collecting this subject - add this scan
365
+ filtered_dirs.append(d)
366
+ elif len(subjects_seen) < max_subjects:
367
+ # New subject and under limit - start collecting this subject
368
+ subjects_seen.add(subject_id)
369
+ filtered_dirs.append(d)
370
+
371
+ # Stop once we have enough subjects
372
+ if len(subjects_seen) >= max_subjects:
373
+ # Collect the remaining cached scans that belong to already-selected subjects
375
+ for remaining_d in cached_dirs[cached_dirs.index(d)+1:]:
376
+ remaining_parts = remaining_d.rstrip('/').split('/')
377
+ try:
378
+ remaining_nlst_idx = remaining_parts.index('NLST')
379
+ remaining_subject_id = remaining_parts[remaining_nlst_idx + 1]
380
+ except (ValueError, IndexError):
381
+ remaining_subject_id = remaining_parts[-3] if len(remaining_parts) >= 3 else remaining_parts[-1]
382
+ if remaining_subject_id in subjects_seen:
383
+ filtered_dirs.append(remaining_d)
384
+ break
385
+
386
+ print(f" ✓ Limited to {len(subjects_seen)} subjects ({len(filtered_dirs)} total scans)")
387
+ return filtered_dirs
388
+ return cached_dirs
389
+
390
+ print(f"Starting parallel directory scan with {num_workers} workers...")
391
+ if filter_pids:
392
+ print(f"⚡ FAST MODE: Only scanning {len(filter_pids)} PIDs (skipping others)")
393
+ else:
394
+ print("Scanning ALL subjects (this may take a while)")
395
+
396
+ # Phase 1: Fast parallel scan to find all directories with DICOM files
397
+ # BUT: Skip subject directories not in filter_pids for MASSIVE speedup
398
+ print("\nPhase 1: Scanning filesystem for DICOM directories...")
399
+ start_time = datetime.now()
400
+
401
+ # Collect all directories first (fast) - WITH EARLY FILTERING
402
+ all_dirs = []
403
+ for dirpath, dirnames, filenames in os.walk(root_dir):
404
+ # EARLY FILTER: If we have filter_pids, only descend into matching PID directories
405
+ if filter_pids:
406
+ path_parts = dirpath.rstrip('/').split('/')
407
+ try:
408
+ nlst_idx = path_parts.index('NLST')
409
+ # If this is a subject directory (one level below NLST)
410
+ if len(path_parts) == nlst_idx + 2:
411
+ subject_id = path_parts[nlst_idx + 1]
412
+ # Skip this subject if not in filter list
413
+ if subject_id not in filter_pids:
414
+ dirnames.clear() # Don't descend into this subject's subdirs
415
+ continue
416
+ except (ValueError, IndexError):
417
+ pass
418
+
419
+ # Quick check: if directory has .dcm files, add to list
420
+ if any(f.endswith('.dcm') for f in filenames):
421
+ all_dirs.append(dirpath)
422
+
423
+ print(f"Found {len(all_dirs)} potential DICOM directories in {(datetime.now() - start_time).total_seconds():.1f}s")
424
+
425
+ # Phase 2: Parallel validation and filtering
426
+ print(f"\nPhase 2: Validating directories in parallel ({num_workers} workers)...")
427
+
428
+ from concurrent.futures import ProcessPoolExecutor, as_completed
429
+
430
+ dicom_dirs = []
431
+ subjects_found = set()
432
+ filtered_stats = {'localizers': 0, 'too_few_slices': 0}
433
+
434
+ with ProcessPoolExecutor(max_workers=num_workers) as executor:
435
+ # Submit all directories for checking
436
+ future_to_dir = {executor.submit(check_directory_for_dicoms, d): d for d in all_dirs}
437
+
438
+ # Process results as they complete
439
+ for i, future in enumerate(as_completed(future_to_dir), 1):
440
+ # Print progress every 1000 dirs (more frequent for visibility)
441
+ if i % 1000 == 0:
442
+ elapsed = (datetime.now() - start_time).total_seconds()
443
+ rate = i / elapsed if elapsed > 0 else 0
444
+ remaining = (len(all_dirs) - i) / rate if rate > 0 else 0
445
+ print(f" [{i}/{len(all_dirs)}] Found: {len(dicom_dirs)} scans from {len(subjects_found)} PIDs | "
446
+ f"Filtered: {filtered_stats['localizers'] + filtered_stats['too_few_slices']} | "
447
+ f"ETA: {remaining/60:.1f} min")
448
+
449
+ try:
450
+ result = future.result()
451
+ if result is None:
452
+ continue
453
+
454
+ dirpath, num_files, subject_id, status = result
455
+
456
+ if status == 'too_few_slices':
457
+ filtered_stats['too_few_slices'] += 1
458
+ elif status == 'localizer':
459
+ filtered_stats['localizers'] += 1
460
+ elif status == 'valid':
461
+ # Check PID filter
462
+ if filter_pids is not None and subject_id not in filter_pids:
463
+ continue
464
+
465
+ # Check subject limit
466
+ if max_subjects is not None and subject_id not in subjects_found and len(subjects_found) >= max_subjects:
467
+ continue
468
+
469
+ subjects_found.add(subject_id)
470
+ dicom_dirs.append(dirpath)
471
+
472
+ # Print when we find a new PID match (helpful for filtered runs)
473
+ if filter_pids and len(dicom_dirs) % 100 == 1:
474
+ print(f" ✓ Found {len(dicom_dirs)} scans so far ({len(subjects_found)} unique PIDs)")
475
+
476
+ # Stop if we've hit subject limit
477
+ if max_subjects is not None and len(subjects_found) >= max_subjects:
478
+ print(f"\n✓ Reached limit of {max_subjects} subjects. Stopping search.")
479
+ # Cancel remaining futures
480
+ for f in future_to_dir:
481
+ f.cancel()
482
+ break
483
+
484
+ except Exception as e:
485
+ continue
486
+
487
+ scan_time = (datetime.now() - start_time).total_seconds()
488
+
489
+ print(f"\n{'='*80}")
490
+ print(f"Directory Scan Complete in {scan_time:.1f}s ({scan_time/60:.1f} minutes)")
491
+ print(f"{'='*80}")
492
+ print(f"Filtering Summary:")
493
+ print(f" ✅ Valid scans found: {len(dicom_dirs)}")
494
+ print(f" 🚫 Localizers filtered: {filtered_stats['localizers']}")
495
+ print(f" ⏭️ Too few slices (≤2) filtered: {filtered_stats['too_few_slices']}")
496
+ print(f" 📊 Unique subjects: {len(subjects_found)}")
497
+ print(f" ⚡ Speed: {len(all_dirs)/scan_time:.0f} dirs/second")
498
+ print(f"{'='*80}\n")
499
+
500
+ # Save to cache if specified
501
+ if cache_file:
502
+ save_directory_cache(dicom_dirs, cache_file)
503
+
504
+ return dicom_dirs
505
+
506
+ def prepare_scan_metadata(scan_dir):
507
+ """
508
+ Prepare metadata for a scan without processing.
509
+
510
+ Args:
511
+ scan_dir: Directory containing DICOM files for one scan
512
+
513
+ Returns:
514
+ tuple: (dicom_file_paths, num_files, subject_id, scan_id)
515
+ """
516
+ # Count DICOM files (ensure they are actual files, not directories)
517
+ dicom_files = [f for f in os.listdir(scan_dir)
518
+ if f.endswith('.dcm') and os.path.isfile(os.path.join(scan_dir, f))]
519
+ num_dicom_files = len(dicom_files)
520
+
521
+ if num_dicom_files == 0:
522
+ raise ValueError("No valid DICOM files found")
523
+
524
+ # Create list of full paths to DICOM files
525
+ dicom_file_paths = [os.path.join(scan_dir, f) for f in dicom_files]
526
+
527
+ # Parse directory path to extract identifiers
528
+ # Path structure: /NLST/<PID>/<date-info>/<scan-info>
529
+ path_parts = scan_dir.rstrip('/').split('/')
530
+ scan_id = path_parts[-1] if path_parts[-1] else path_parts[-2]
531
+
532
+ # Extract PID from path
533
+ try:
534
+ nlst_idx = path_parts.index('NLST')
535
+ subject_id = path_parts[nlst_idx + 1] # PID is right after 'NLST'
536
+ except (ValueError, IndexError):
537
+ # Fallback to old logic
538
+ subject_id = path_parts[-3] if len(path_parts) >= 3 else path_parts[-1]
539
+
540
+ return dicom_file_paths, num_dicom_files, subject_id, scan_id
541
+
542
+ def save_checkpoint(all_embeddings, all_metadata, failed, output_dir, checkpoint_num):
543
+ """
544
+ Save a checkpoint of embeddings and metadata.
545
+
546
+ Args:
547
+ all_embeddings: List of embedding arrays
548
+ all_metadata: List of metadata dictionaries
549
+ failed: List of failed scans
550
+ output_dir: Output directory
551
+ checkpoint_num: Checkpoint number
552
+ """
553
+ print(f"\n💾 Saving checkpoint {checkpoint_num}...")
554
+
555
+ # Convert embeddings to array
556
+ embeddings_array = np.array(all_embeddings)
557
+ embedding_dim = int(embeddings_array.shape[1]) if len(embeddings_array.shape) > 1 else int(embeddings_array.shape[0])
558
+
559
+ # Create DataFrame
560
+ df_data = {
561
+ 'case_number': [m['case_number'] for m in all_metadata],
562
+ 'subject_id': [m['subject_id'] for m in all_metadata],
563
+ 'scan_id': [m['scan_id'] for m in all_metadata],
564
+ 'timepoint': [m.get('timepoint') for m in all_metadata],
565
+ 'dicom_directory': [m['dicom_directory'] for m in all_metadata],
566
+ 'num_dicom_files': [m['num_dicom_files'] for m in all_metadata],
567
+ 'embedding_index': [m['embedding_index'] for m in all_metadata],
568
+ 'embedding': list(embeddings_array)
569
+ }
570
+ df = pd.DataFrame(df_data)
571
+
572
+ # Save checkpoint parquet
573
+ checkpoint_path = os.path.join(output_dir, f"checkpoint_{checkpoint_num}_embeddings.parquet")
574
+ df.to_parquet(checkpoint_path, index=False, compression='snappy')
575
+ print(f" ✓ Saved embeddings checkpoint: {checkpoint_path}")
576
+
577
+ # Save checkpoint metadata
578
+ checkpoint_metadata = {
579
+ "checkpoint_num": checkpoint_num,
580
+ "timestamp": datetime.now().isoformat(),
581
+ "total_scans": len(all_embeddings),
582
+ "num_failed_scans": len(failed),
583
+ "embedding_shape": list(embeddings_array.shape),
584
+ "scans": all_metadata,
585
+ "failed_scans": failed
586
+ }
587
+ metadata_path = os.path.join(output_dir, f"checkpoint_{checkpoint_num}_metadata.json")
588
+ with open(metadata_path, 'w') as f:
589
+ json.dump(checkpoint_metadata, f, indent=2)
590
+ print(f" ✓ Saved metadata checkpoint: {metadata_path}")
591
+ print(f"💾 Checkpoint {checkpoint_num} complete!\n")
592
+
593
+ def process_scan(model, device, scan_dir):
594
+ """
595
+ Process a single scan directory and extract embeddings.
596
+
597
+ Args:
598
+ model: Pre-loaded SybilHFWrapper model
599
+ device: Device to run on (e.g., cuda:0, cuda:1)
600
+ scan_dir: Directory containing DICOM files for one scan
601
+
602
+ Returns:
603
+ tuple: (embeddings, scan_metadata)
604
+ """
605
+ dicom_file_paths, num_dicom_files, subject_id, scan_id = prepare_scan_metadata(scan_dir)
606
+
607
+ print(f"\nProcessing: {scan_dir}")
608
+ print(f"DICOM files: {num_dicom_files}")
609
+
610
+ # Extract embeddings
611
+ embeddings = extract_embeddings(model, dicom_file_paths, device)
612
+
613
+ print(f"Embedding shape: {embeddings.shape}")
614
+
615
+ # Extract timepoint from path (e.g., 1999 -> T0, 2000 -> T1)
616
+ timepoint = extract_timepoint_from_path(scan_dir)
617
+ if timepoint:
618
+ print(f"Timepoint: {timepoint}")
619
+ else:
620
+ print(f"Timepoint: Not detected")
621
+
622
+ # Create metadata for this scan
623
+ scan_metadata = {
624
+ "case_number": subject_id, # Case number (e.g., 205749)
625
+ "subject_id": subject_id,
626
+ "scan_id": scan_id,
627
+ "timepoint": timepoint, # T0, T1, T2, etc. or None
628
+ "dicom_directory": scan_dir,
629
+ "num_dicom_files": num_dicom_files,
630
+ "embedding_index": None, # Will be set later
631
+ "statistics": {
632
+ "mean": float(np.mean(embeddings)),
633
+ "std": float(np.std(embeddings)),
634
+ "min": float(np.min(embeddings)),
635
+ "max": float(np.max(embeddings))
636
+ }
637
+ }
638
+
639
+ return embeddings, scan_metadata
640
+
641
+ # Main execution
642
+ if __name__ == "__main__":
643
+ import argparse
644
+
645
+ # Parse command line arguments
646
+ parser = argparse.ArgumentParser(description='Extract Sybil embeddings from DICOM scans')
647
+
648
+ # Input/Output
649
+ parser.add_argument('--root-dir', type=str, required=True,
650
+ help='Root directory containing DICOM files (e.g., /path/to/NLST)')
651
+ parser.add_argument('--pid-csv', type=str, default=None,
652
+ help='CSV file with "pid" column to filter subjects (e.g., subsets/hybridModels-train.csv)')
653
+ parser.add_argument('--output-dir', type=str, default='embeddings_output',
654
+ help='Output directory for embeddings (default: embeddings_output)')
655
+ parser.add_argument('--max-subjects', type=int, default=None,
656
+ help='Maximum number of subjects to process (for testing)')
657
+
658
+ # Performance tuning
659
+ parser.add_argument('--num-gpus', type=int, default=1,
660
+ help='Number of GPUs to use (default: 1)')
661
+ parser.add_argument('--num-parallel', type=int, default=1,
662
+ help='Number of parallel scans to process simultaneously (default: 1, recommended: 1-4 depending on GPU memory)')
663
+ parser.add_argument('--num-workers', type=int, default=4,
664
+ help='Number of parallel workers for directory scanning (default: 4, recommended: 4-12 depending on storage speed)')
665
+ parser.add_argument('--checkpoint-interval', type=int, default=1000,
666
+ help='Save checkpoint every N scans (default: 1000)')
667
+
668
+ args = parser.parse_args()
669
+
670
+ # ==========================================
671
+ # CONFIGURATION
672
+ # ==========================================
673
+ root_dir = args.root_dir
674
+ output_dir = args.output_dir
675
+ max_subjects = args.max_subjects
676
+ num_gpus = args.num_gpus
677
+ num_parallel_scans = args.num_parallel
678
+ num_scan_workers = args.num_workers
679
+ checkpoint_interval = args.checkpoint_interval
680
+
681
+ # Always use the main cache file from the full run
682
+ main_cache = "embeddings_output_full/directory_cache.json"
683
+ if os.path.exists(main_cache):
684
+ cache_file = main_cache
685
+ print(f"✓ Found main directory cache: {main_cache}")
686
+ else:
687
+ cache_file = os.path.join(output_dir, "directory_cache.json")
688
+
689
+ # Verify root directory exists
690
+ if not os.path.exists(root_dir):
691
+ raise ValueError(f"Root directory does not exist: {root_dir}")
692
+
693
+ # Load PIDs from CSV if provided
694
+ filter_pids = None
695
+ if args.pid_csv:
696
+ print(f"Loading subject PIDs from: {args.pid_csv}")
697
+ import pandas as pd
698
+ csv_data = pd.read_csv(args.pid_csv)
699
+ filter_pids = set(str(pid) for pid in csv_data['pid'].unique())
700
+ print(f" Found {len(filter_pids)} unique PIDs to extract")
701
+ print(f" Examples: {list(filter_pids)[:5]}")
702
+
703
+ # Create output directory
704
+ os.makedirs(output_dir, exist_ok=True)
705
+
706
+ # Print configuration
707
+ print(f"\n{'='*80}")
708
+ print(f"CONFIGURATION")
709
+ print(f"{'='*80}")
710
+ print(f"Root directory: {root_dir}")
711
+ print(f"Output directory: {output_dir}")
712
+ print(f"Number of GPUs: {num_gpus}")
713
+ print(f"Parallel scans: {num_parallel_scans} (recommended: 1-4 depending on GPU memory)")
714
+ print(f"Directory scan workers: {num_scan_workers} (recommended: 4-12 depending on storage)")
715
+ print(f"Checkpoint interval: {checkpoint_interval} scans")
716
+ if filter_pids:
717
+ print(f"Filtering to: {len(filter_pids)} PIDs from CSV")
718
+ if max_subjects:
719
+ print(f"Max subjects: {max_subjects}")
720
+ print(f"{'='*80}\n")
721
+
722
+ # Warning about memory requirements
723
+ if num_parallel_scans > 1:
724
+ estimated_vram = max(1, num_parallel_scans // num_gpus) * 10  # at least one scan per GPU at a time
725
+ print(f"⚠️ MEMORY WARNING:")
726
+ print(f" Parallel processing requires ~{estimated_vram}GB VRAM per GPU")
727
+ print(f" If you encounter OOM errors, reduce --num-parallel to 1-2")
728
+ print(f" Current: {num_parallel_scans} scans across {num_gpus} GPU(s)\n")
729
+
730
+ # Find all directories containing DICOM files (FAST with parallel processing!)
731
+ # Will use cached directory list if available, otherwise scan and save cache
732
+ dicom_dirs = find_dicom_directories(root_dir, max_subjects=max_subjects,
733
+ num_workers=num_scan_workers, cache_file=cache_file,
734
+ filter_pids=filter_pids)
735
+
736
+ if len(dicom_dirs) == 0:
737
+ raise ValueError(f"No directories with DICOM files found in {root_dir}")
738
+
739
+ print(f"\n{'='*80}")
740
+ print(f"Found {len(dicom_dirs)} directories containing DICOM files")
741
+ print(f"{'='*80}\n")
742
+
743
+ # Detect and load models on multiple GPUs
744
+ print(f"🎮 Detected {num_gpus} GPU(s)")
745
+ print(f"🚀 Will process {num_parallel_scans} scans in parallel ({num_parallel_scans // num_gpus} per GPU)")
746
+ print(f"💾 Checkpoints will be saved every {checkpoint_interval} scans\n")
747
+
748
+ # Load models on each GPU
749
+ models_and_devices = []
750
+ for gpu_id in range(num_gpus):
751
+ model, device = load_model(gpu_id)
752
+ models_and_devices.append((model, device, gpu_id))
753
+
754
+ # Process each scan directory and collect all embeddings
755
+ all_embeddings = []
756
+ all_metadata = []
757
+ failed = []
758
+ checkpoint_counter = 0
759
+
760
+ if num_parallel_scans > 1:
761
+ # Parallel processing of multiple scans across multiple GPUs
762
+ print(f"Processing {num_parallel_scans} scans in parallel across {num_gpus} GPU(s)...")
763
+ print(f"Note: This requires ~{(num_parallel_scans // num_gpus) * 10}GB VRAM per GPU.\n")
764
+
765
+ from functools import partial
766
+ from concurrent.futures import as_completed
767
+
768
+ # Process scans in batches for checkpoint saving
769
+ batch_size = checkpoint_interval
770
+ num_batches = (len(dicom_dirs) + batch_size - 1) // batch_size
771
+
772
+ for batch_idx in range(num_batches):
773
+ start_idx = batch_idx * batch_size
774
+ end_idx = min(start_idx + batch_size, len(dicom_dirs))
775
+ batch_dirs = dicom_dirs[start_idx:end_idx]
776
+
777
+ print(f"\n{'='*80}")
778
+ print(f"Processing batch {batch_idx + 1}/{num_batches} (scans {start_idx + 1} to {end_idx})")
779
+ print(f"{'='*80}\n")
780
+
781
+ # Use ThreadPoolExecutor for parallel scan processing
782
+ # IMPORTANT: max_workers limits concurrent execution to prevent OOM
783
+ with ThreadPoolExecutor(max_workers=num_parallel_scans) as executor:
784
+ # Submit scans in controlled batches to avoid memory issues
785
+ # We submit only max_workers scans at once, then submit more as they complete
786
+ future_to_info = {}
787
+ scan_queue = list(enumerate(batch_dirs))
788
+ scans_submitted = 0
789
+
790
+ # Submit initial batch (up to max_workers scans)
791
+ while scan_queue and scans_submitted < num_parallel_scans:
792
+ i, scan_dir = scan_queue.pop(0)
793
+ # Select GPU in round-robin fashion
794
+ gpu_idx = i % num_gpus
795
+ model, device, gpu_id = models_and_devices[gpu_idx]
796
+
797
+ # Create partial function with model and device
798
+ process_func = partial(process_scan, model, device)
799
+ future = executor.submit(process_func, scan_dir)
800
+ future_to_info[future] = (start_idx + i + 1, scan_dir, gpu_id)
801
+ scans_submitted += 1
802
+
803
+ # Process results as they complete and submit new scans
804
+ while future_to_info:
805
+ # Wait for next completion
806
+ done_futures = []
807
+ for future in list(future_to_info.keys()):
808
+ if future.done():
809
+ done_futures.append(future)
810
+
811
+ if not done_futures:
812
+ import time
813
+ time.sleep(0.1)
814
+ continue
815
+
816
+ # Process completed futures
817
+ for future in done_futures:
818
+ scan_num, scan_dir, gpu_id = future_to_info.pop(future)
819
+ try:
820
+ print(f"[{scan_num}/{len(dicom_dirs)}] Processing on GPU {gpu_id}...")
821
+ embeddings, scan_metadata = future.result()
822
+
823
+ # Set the index for this scan
824
+ scan_metadata["embedding_index"] = len(all_embeddings)
825
+
826
+ # Collect embeddings and metadata
827
+ all_embeddings.append(embeddings)
828
+ all_metadata.append(scan_metadata)
829
+
830
+ except Exception as e:
831
+ print(f"ERROR processing {scan_dir}: {e}")
832
+ failed.append({"scan_dir": scan_dir, "error": str(e)})
833
+
834
+ # Submit next scan from queue
835
+ if scan_queue:
836
+ i, next_scan_dir = scan_queue.pop(0)
837
+ gpu_idx = i % num_gpus
838
+ model, device, gpu_id = models_and_devices[gpu_idx]
839
+
840
+ process_func = partial(process_scan, model, device)
841
+ new_future = executor.submit(process_func, next_scan_dir)
842
+ future_to_info[new_future] = (start_idx + i + 1, next_scan_dir, gpu_id)
843
+
844
+ # Save checkpoint after each batch
845
+ checkpoint_counter += 1
846
+ save_checkpoint(all_embeddings, all_metadata, failed, output_dir, checkpoint_counter)
847
+
848
+ print(f"Progress: {len(all_embeddings)}/{len(dicom_dirs)} scans completed "
849
+ f"({len(all_embeddings)/len(dicom_dirs)*100:.1f}%)\n")
850
+ else:
851
+ # Sequential processing (original behavior)
852
+ model, device, gpu_id = models_and_devices[0] # Use first GPU
853
+
854
+ for i, scan_dir in enumerate(dicom_dirs, 1):
855
+ try:
856
+ print(f"\n[{i}/{len(dicom_dirs)}] Processing scan...")
857
+
858
+ # Process scan and get results
859
+ embeddings, scan_metadata = process_scan(model, device, scan_dir)
860
+
861
+ # Set the index for this scan
862
+ scan_metadata["embedding_index"] = len(all_embeddings)
863
+
864
+ # Collect embeddings and metadata
865
+ all_embeddings.append(embeddings)
866
+ all_metadata.append(scan_metadata)
867
+
868
+ # Save checkpoint every checkpoint_interval scans
869
+ if i % checkpoint_interval == 0:
870
+ checkpoint_counter += 1
871
+ save_checkpoint(all_embeddings, all_metadata, failed, output_dir, checkpoint_counter)
872
+
873
+ except Exception as e:
874
+ print(f"ERROR processing {scan_dir}: {e}")
875
+ failed.append({"scan_dir": scan_dir, "error": str(e)})
876
+
877
+ # Convert embeddings list to numpy array
878
+ # Shape will be (num_scans, embedding_dim)
879
+ embeddings_array = np.array(all_embeddings)
880
+ embedding_dim = int(embeddings_array.shape[1]) if len(embeddings_array.shape) > 1 else int(embeddings_array.shape[0])
881
+
882
+ # Create DataFrame with embeddings and metadata for Parquet
883
+ # Store embeddings as a single array column
884
+ df_data = {
885
+ 'case_number': [m['case_number'] for m in all_metadata],
886
+ 'subject_id': [m['subject_id'] for m in all_metadata],
887
+ 'scan_id': [m['scan_id'] for m in all_metadata],
888
+ 'timepoint': [m.get('timepoint') for m in all_metadata], # T0, T1, T2, etc.
889
+ 'dicom_directory': [m['dicom_directory'] for m in all_metadata],
890
+ 'num_dicom_files': [m['num_dicom_files'] for m in all_metadata],
891
+ 'embedding_index': [m['embedding_index'] for m in all_metadata],
892
+ 'embedding': list(embeddings_array) # Store as list of arrays
893
+ }
894
+
895
+ # Create DataFrame
896
+ df = pd.DataFrame(df_data)
897
+
898
+ # Save final complete file as Parquet
899
+ embeddings_filename = "all_embeddings.parquet"
900
+ embeddings_path = os.path.join(output_dir, embeddings_filename)
901
+ df.to_parquet(embeddings_path, index=False, compression='snappy')
902
+ print(f"\n✅ Saved FINAL embeddings to Parquet: {embeddings_path}")
903
+
904
+ # Create comprehensive metadata JSON
905
+ dataset_metadata = {
906
+ "dataset_info": {
907
+ "root_directory": root_dir,
908
+ "total_scans": len(all_embeddings),
909
+ "failed_scans": len(failed),
910
+ "embedding_shape": list(embeddings_array.shape),
911
+ "embedding_dim": embedding_dim,
912
+ "extraction_timestamp": datetime.now().isoformat(),
913
+ "file_format": "parquet"
914
+ },
915
+ "model_info": {
916
+ "model": "Lab-Rasool/sybil",
917
+ "layer": "after_relu_before_dropout",
918
+ "ensemble_averaged": True,
919
+ "num_ensemble_models": 5
920
+ },
921
+ "embeddings_file": embeddings_filename,
922
+ "parquet_schema": {
923
+ "metadata_columns": ["case_number", "subject_id", "scan_id", "timepoint", "dicom_directory", "num_dicom_files", "embedding_index"],
924
+ "embedding_column": "embedding",
925
+ "embedding_shape": f"({embedding_dim},)",
926
+ "total_columns": 8,
927
+ "timepoint_info": "T0=1999, T1=2000, T2=2001, etc. Extracted from year in path. Can be None if not detected."
928
+ },
929
+ "filtering_info": {
930
+ "localizer_detection": "Scans identified as localizers (by folder name or DICOM metadata) are filtered out",
931
+ "min_slices": "Scans with ≤2 DICOM files are filtered out (likely localizers)",
932
+ "accepted_scans": len(all_embeddings)
933
+ },
934
+ "scans": all_metadata,
935
+ "failed_scans": failed
936
+ }
937
+
938
+ metadata_filename = "dataset_metadata.json"
939
+ metadata_path = os.path.join(output_dir, metadata_filename)
940
+ with open(metadata_path, 'w') as f:
941
+ json.dump(dataset_metadata, f, indent=2)
942
+ print(f"✅ Saved FINAL metadata: {metadata_path}")
943
+
944
+ # Summary
945
+ print(f"\n{'='*80}")
946
+ print(f"PROCESSING COMPLETE")
947
+ print(f"{'='*80}")
948
+ print(f"Successfully processed: {len(all_embeddings)}/{len(dicom_dirs)} scans")
949
+ print(f"Failed: {len(failed)}/{len(dicom_dirs)} scans")
950
+ print(f"\nEmbeddings array shape: {embeddings_array.shape}")
951
+ print(f"Saved embeddings to: {embeddings_path}")
952
+ print(f"Saved metadata to: {metadata_path}")
953
+
954
+ # Timepoint summary
955
+ timepoint_counts = {}
956
+ for m in all_metadata:
957
+ tp = m.get('timepoint', 'Unknown')
958
+ timepoint_counts[tp] = timepoint_counts.get(tp, 0) + 1
959
+
960
+ if timepoint_counts:
961
+ print(f"\n📅 Timepoint Distribution:")
962
+ for tp in sorted(timepoint_counts.keys(), key=lambda x: (x is None, x)):
963
+ count = timepoint_counts[tp]
964
+ if tp is None:
965
+ print(f" Unknown/Not detected: {count} scans")
966
+ else:
967
+ print(f" {tp}: {count} scans")
968
+
969
+ if failed:
970
+ print(f"\nFailed scans: {len(failed)}")
971
+ for fail_info in failed[:5]: # Show first 5 failures
972
+ print(f" - {fail_info['scan_dir']}")
973
+ print(f" Error: {fail_info['error']}")
974
+ if len(failed) > 5:
975
+ print(f" ... and {len(failed) - 5} more failures")
976
+
977
+ print(f"\n{'='*80}")
978
+ print(f"For downstream training, load embeddings with:")
979
+ print(f" import pandas as pd")
980
+ print(f" import numpy as np")
981
+ print(f" df = pd.read_parquet('{embeddings_path}')")
982
+ print(f" # Total rows: {len(df)}, Total columns: {len(df.columns)}")
983
+ print(f" # Extract embeddings array: embeddings = np.stack(df['embedding'].values)")
984
+ print(f" # Shape: {embeddings_array.shape}")
985
+ print(f" # Access individual: df.loc[0, 'embedding'] -> array of shape ({embedding_dim},)")
986
+ print(f"{'='*80}")