from huggingface_hub import snapshot_download

import sys
import os
import torch
import numpy as np
import pandas as pd
import json
import re
import pydicom
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from pathlib import Path
import threading
import multiprocessing as mp


# Download the Sybil checkpoint from the Hugging Face Hub and make its modules importable
model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
sys.path.append(model_path)

from modeling_sybil_hf import SybilHFWrapper
from configuration_sybil import SybilConfig


def load_model(device_id=0):
    """
    Load and initialize the Sybil model once.

    Args:
        device_id: GPU device ID to load the model on

    Returns:
        Tuple of (initialized SybilHFWrapper model, torch.device it was placed on)
    """
    print(f"Loading Sybil model on GPU {device_id}...")
    config = SybilConfig()
    model = SybilHFWrapper(config)

    device = torch.device(f'cuda:{device_id}')
    model.device = device

    # Move each ensemble member to the target GPU and switch it to eval mode
    for m in model.models:
        m.to(device)
        m.eval()

    print(f"Model loaded successfully on GPU {device_id}!")
    print(f"  Model internal device: {model.device}")
    return model, device


def is_localizer_scan(dicom_folder):
    """
    Check if a DICOM folder contains a localizer/scout scan.
    Based on preprocessing.py logic.

    Returns:
        Tuple of (is_localizer, reason)
    """
    folder_path = Path(dicom_folder)
    folder_name = folder_path.name.lower()
    localizer_keywords = ['localizer', 'scout', 'topogram', 'surview', 'scanogram']

    if any(keyword in folder_name for keyword in localizer_keywords):
        return True, f"Folder name contains localizer keyword: {folder_name}"

    try:
        dcm_files = list(folder_path.glob("*.dcm"))
        if not dcm_files:
            return False, "No DICOM files found"

        sample_files = dcm_files[:min(3, len(dcm_files))]
        for dcm_file in sample_files:
            try:
                dcm = pydicom.dcmread(str(dcm_file), stop_before_pixels=True)

                if hasattr(dcm, 'ImageType'):
                    image_type_str = ' '.join(str(val).lower() for val in dcm.ImageType)
                    if any(keyword in image_type_str for keyword in localizer_keywords):
                        return True, f"ImageType indicates localizer: {dcm.ImageType}"

                if hasattr(dcm, 'SeriesDescription'):
                    if any(keyword in dcm.SeriesDescription.lower() for keyword in localizer_keywords):
                        return True, f"SeriesDescription indicates localizer: {dcm.SeriesDescription}"
            except Exception:
                continue
    except Exception:
        pass

    return False, "Not a localizer scan"


def extract_timepoint_from_path(scan_dir):
    """
    Extract timepoint from scan directory path based on year.
    1999 -> T0, 2000 -> T1, 2001 -> T2, etc.

    Looks for year patterns in folder names in date format MM-DD-YYYY.

    Args:
        scan_dir: Directory path string

    Returns:
        Timepoint string (e.g., 'T0', 'T1', 'T2') or None if not found
    """
    path_parts = scan_dir.split('/')

    date_pattern = r'^\d{2}-\d{2}-(19\d{2}|20\d{2})'
    base_year = 1999

    for part in path_parts:
        match = re.match(date_pattern, part)
        if match:
            year = int(match.group(1))
            if 1999 <= year <= 2010:
                timepoint_num = year - base_year
                print(f"  DEBUG: Found year {year} in '{part}' -> T{timepoint_num}")
                return f'T{timepoint_num}'

    return None


def extract_embedding_single_model(model_idx, ensemble_model, pixel_values, device):
    """
    Extract embedding from a single ensemble model.

    Args:
        model_idx: Index of the model in the ensemble
        ensemble_model: Single model from the ensemble
        pixel_values: Preprocessed pixel values tensor (already on correct device)
        device: Device to run on (e.g., cuda:0, cuda:1)

    Returns:
        numpy array of embeddings from this model
    """
    embeddings_buffer = []

    def create_hook(buffer):
        def hook(module, input, output):
            # Copy the activation to CPU so it survives after the forward pass
            buffer.append(output.detach().cpu())
        return hook

    # Capture the activation right after the ReLU layer (before Dropout) via a forward hook
    hook_handle = ensemble_model.relu.register_forward_hook(create_hook(embeddings_buffer))

    with torch.no_grad():
        _ = ensemble_model(pixel_values=pixel_values)

    hook_handle.remove()

    if embeddings_buffer:
        embedding = embeddings_buffer[0].numpy().squeeze()
        print(f"Model {model_idx + 1}: Embedding shape = {embedding.shape}")
        return embedding
    return None


def extract_embeddings(model, dicom_paths, device, use_parallel=True):
    """
    Extract embeddings from the layer after ReLU, before Dropout.
    Processes ensemble models in parallel for speed.

    Args:
        model: Pre-loaded SybilHFWrapper model
        dicom_paths: List of DICOM file paths
        device: Device to run on (e.g., cuda:0, cuda:1)
        use_parallel: If True, process ensemble models in parallel

    Returns:
        numpy array of shape (512,) - averaged embeddings across ensemble
    """
    with torch.no_grad():
        # Preprocess the DICOM series once and share the tensor across all ensemble members
        pixel_values = model.preprocess_dicom(dicom_paths)

        if use_parallel:
            all_embeddings = []
            with ThreadPoolExecutor(max_workers=len(model.models)) as executor:
                futures = [
                    executor.submit(extract_embedding_single_model, model_idx, ensemble_model, pixel_values, device)
                    for model_idx, ensemble_model in enumerate(model.models)
                ]
                for future in futures:
                    embedding = future.result()
                    if embedding is not None:
                        all_embeddings.append(embedding)
        else:
            all_embeddings = []
            for model_idx, ensemble_model in enumerate(model.models):
                embedding = extract_embedding_single_model(model_idx, ensemble_model, pixel_values, device)
                if embedding is not None:
                    all_embeddings.append(embedding)

    # Average the per-model embeddings into a single vector
    averaged_embedding = np.mean(all_embeddings, axis=0)
    return averaged_embedding


def check_directory_for_dicoms(dirpath):
    """
    Check a single directory for valid DICOM files.
    Returns (dirpath, num_files, subject_id, filter_reason) or None if invalid.
    """
    try:
        dcm_files = [f for f in os.listdir(dirpath)
                     if f.endswith('.dcm') and os.path.isfile(os.path.join(dirpath, f))]

        if not dcm_files:
            return None

        num_files = len(dcm_files)

        if num_files <= 2:
            return (dirpath, num_files, None, 'too_few_slices')

        is_loc, _ = is_localizer_scan(dirpath)
        if is_loc:
            return (dirpath, num_files, None, 'localizer')

        path_parts = dirpath.rstrip('/').split('/')

        try:
            nlst_idx = path_parts.index('NLST')
            subject_id = path_parts[nlst_idx + 1]
        except (ValueError, IndexError):
            subject_id = path_parts[-3] if len(path_parts) >= 3 else path_parts[-1]

        return (dirpath, num_files, subject_id, 'valid')

    except Exception:
        return None


def save_directory_cache(dicom_dirs, cache_file):
    """
    Save the list of DICOM directories to a cache file.

    Args:
        dicom_dirs: List of directory paths
        cache_file: Path to cache file
    """
    print(f"\n💾 Saving directory cache to {cache_file}...")
    cache_data = {
        "timestamp": datetime.now().isoformat(),
        "num_directories": len(dicom_dirs),
        "directories": dicom_dirs
    }
    with open(cache_file, 'w') as f:
        json.dump(cache_data, f, indent=2)
    print(f"✓ Cache saved with {len(dicom_dirs)} directories\n")


def load_directory_cache(cache_file):
    """
    Load the list of DICOM directories from a cache file.

    Args:
        cache_file: Path to cache file

    Returns:
        List of directory paths, or None if cache doesn't exist or is invalid
    """
    if not os.path.exists(cache_file):
        return None

    try:
        with open(cache_file, 'r') as f:
            cache_data = json.load(f)

        dicom_dirs = cache_data.get("directories", [])
        timestamp = cache_data.get("timestamp", "unknown")

        print(f"\n✓ Loaded directory cache from {cache_file}")
        print(f"  Cache created: {timestamp}")
        print(f"  Directories: {len(dicom_dirs)}\n")

        return dicom_dirs
    except Exception as e:
        print(f"⚠️ Failed to load cache: {e}")
        return None
|
|
def find_dicom_directories(root_dir, max_subjects=None, num_workers=12, cache_file=None, filter_pids=None): |
|
|
""" |
|
|
Walk through directory tree and find all directories containing DICOM files. |
|
|
Uses parallel processing for much faster scanning of large directory trees. |
|
|
Only returns leaf directories (directories with .dcm files, not their parents). |
|
|
    Filters out localizer scans and scans with 2 or fewer DICOM slices.
|
|
|
|
|
Args: |
|
|
root_dir: Root directory to search |
|
|
max_subjects: Optional maximum number of unique subjects to process (None = all) |
|
|
num_workers: Number of parallel workers for directory scanning (default: 12) |
|
|
cache_file: Optional path to cache file for saving/loading directory list |
|
|
filter_pids: Optional set of PIDs to filter (only include these subjects) |
|
|
|
|
|
Returns: |
|
|
List of directory paths containing .dcm files |
|
|
""" |
|
|
|
|
|
if cache_file: |
|
|
cached_dirs = load_directory_cache(cache_file) |
|
|
if cached_dirs is not None: |
|
|
print("✓ Using cached directory list (skipping scan)") |
|
|
|
|
|
|
|
|
if filter_pids: |
|
|
print(f" Filtering to {len(filter_pids)} PIDs from CSV...") |
|
|
filtered_dirs = [] |
|
|
for d in cached_dirs: |
|
|
|
|
|
path_parts = d.rstrip('/').split('/') |
|
|
try: |
|
|
nlst_idx = path_parts.index('NLST') |
|
|
subject_id = path_parts[nlst_idx + 1] |
|
|
except (ValueError, IndexError): |
|
|
subject_id = path_parts[-3] if len(path_parts) >= 3 else path_parts[-1] |
|
|
|
|
|
if subject_id in filter_pids: |
|
|
filtered_dirs.append(d) |
|
|
print(f" ✓ Found {len(filtered_dirs)} scans matching PIDs") |
|
|
return filtered_dirs |
|
|
|
|
|
|
|
|
if max_subjects: |
|
|
subjects_seen = set() |
|
|
filtered_dirs = [] |
|
|
for d in cached_dirs: |
|
|
|
|
|
path_parts = d.rstrip('/').split('/') |
|
|
try: |
|
|
nlst_idx = path_parts.index('NLST') |
|
|
subject_id = path_parts[nlst_idx + 1] |
|
|
except (ValueError, IndexError): |
|
|
subject_id = path_parts[-3] if len(path_parts) >= 3 else path_parts[-1] |
|
|
|
|
|
|
|
|
|
|
|
if subject_id in subjects_seen: |
|
|
|
|
|
filtered_dirs.append(d) |
|
|
elif len(subjects_seen) < max_subjects: |
|
|
|
|
|
subjects_seen.add(subject_id) |
|
|
filtered_dirs.append(d) |
|
|
|
|
|
|
|
|
                    # NOTE: once the subject limit is reached, the first branch above still
                    # collects additional scans for subjects already seen, while new subjects
                    # are skipped, so no extra look-ahead over cached_dirs is needed here.
|
|
|
|
|
print(f" ✓ Limited to {len(subjects_seen)} subjects ({len(filtered_dirs)} total scans)") |
|
|
return filtered_dirs |
|
|
return cached_dirs |
|
|
|
|
|
print(f"Starting parallel directory scan with {num_workers} workers...") |
|
|
if filter_pids: |
|
|
print(f"⚡ FAST MODE: Only scanning {len(filter_pids)} PIDs (skipping others)") |
|
|
else: |
|
|
print("Scanning ALL subjects (this may take a while)") |
|
|
|
|
|
|
|
|
|
|
|
print("\nPhase 1: Scanning filesystem for DICOM directories...") |
|
|
start_time = datetime.now() |
|
|
|
|
|
|
|
|
all_dirs = [] |
|
|
for dirpath, dirnames, filenames in os.walk(root_dir): |
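        # Prune the walk: if this is a subject directory whose PID is not requested,
        # clearing dirnames stops os.walk from descending into that subtree at all.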
|
|
|
|
|
if filter_pids: |
|
|
path_parts = dirpath.rstrip('/').split('/') |
|
|
try: |
|
|
nlst_idx = path_parts.index('NLST') |
|
|
|
|
|
if len(path_parts) == nlst_idx + 2: |
|
|
subject_id = path_parts[nlst_idx + 1] |
|
|
|
|
|
if subject_id not in filter_pids: |
|
|
dirnames.clear() |
|
|
continue |
|
|
except (ValueError, IndexError): |
|
|
pass |
|
|
|
|
|
|
|
|
if any(f.endswith('.dcm') for f in filenames): |
|
|
all_dirs.append(dirpath) |
|
|
|
|
|
print(f"Found {len(all_dirs)} potential DICOM directories in {(datetime.now() - start_time).total_seconds():.1f}s") |
|
|
|
|
|
|
|
|
print(f"\nPhase 2: Validating directories in parallel ({num_workers} workers)...") |
|
|
|
|
|
from concurrent.futures import ProcessPoolExecutor, as_completed |
|
|
|
|
|
dicom_dirs = [] |
|
|
subjects_found = set() |
|
|
filtered_stats = {'localizers': 0, 'too_few_slices': 0} |
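    # Validate candidates in worker processes; DICOM header parsing with pydicom is CPU-bound,
    # so a process pool scales better here than threads.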
|
|
|
|
|
with ProcessPoolExecutor(max_workers=num_workers) as executor: |
|
|
|
|
|
future_to_dir = {executor.submit(check_directory_for_dicoms, d): d for d in all_dirs} |
|
|
|
|
|
|
|
|
for i, future in enumerate(as_completed(future_to_dir), 1): |
|
|
|
|
|
if i % 1000 == 0: |
|
|
elapsed = (datetime.now() - start_time).total_seconds() |
|
|
rate = i / elapsed if elapsed > 0 else 0 |
|
|
remaining = (len(all_dirs) - i) / rate if rate > 0 else 0 |
|
|
print(f" [{i}/{len(all_dirs)}] Found: {len(dicom_dirs)} scans from {len(subjects_found)} PIDs | " |
|
|
f"Filtered: {filtered_stats['localizers'] + filtered_stats['too_few_slices']} | " |
|
|
f"ETA: {remaining/60:.1f} min") |
|
|
|
|
|
try: |
|
|
result = future.result() |
|
|
if result is None: |
|
|
continue |
|
|
|
|
|
dirpath, num_files, subject_id, status = result |
|
|
|
|
|
if status == 'too_few_slices': |
|
|
filtered_stats['too_few_slices'] += 1 |
|
|
elif status == 'localizer': |
|
|
filtered_stats['localizers'] += 1 |
|
|
elif status == 'valid': |
|
|
|
|
|
if filter_pids is not None and subject_id not in filter_pids: |
|
|
continue |
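                        # Once the subject cap is hit, only scans from already-counted subjects are kept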
|
|
|
|
|
|
|
|
if max_subjects is not None and subject_id not in subjects_found and len(subjects_found) >= max_subjects: |
|
|
continue |
|
|
|
|
|
subjects_found.add(subject_id) |
|
|
dicom_dirs.append(dirpath) |
|
|
|
|
|
|
|
|
if filter_pids and len(dicom_dirs) % 100 == 1: |
|
|
print(f" ✓ Found {len(dicom_dirs)} scans so far ({len(subjects_found)} unique PIDs)") |
|
|
|
|
|
|
|
|
if max_subjects is not None and len(subjects_found) >= max_subjects: |
|
|
print(f"\n✓ Reached limit of {max_subjects} subjects. Stopping search.") |
|
|
|
|
|
for f in future_to_dir: |
|
|
f.cancel() |
|
|
break |
|
|
|
|
|
except Exception as e: |
|
|
continue |
|
|
|
|
|
scan_time = (datetime.now() - start_time).total_seconds() |
|
|
|
|
|
print(f"\n{'='*80}") |
|
|
print(f"Directory Scan Complete in {scan_time:.1f}s ({scan_time/60:.1f} minutes)") |
|
|
print(f"{'='*80}") |
|
|
print(f"Filtering Summary:") |
|
|
print(f" ✅ Valid scans found: {len(dicom_dirs)}") |
|
|
print(f" 🚫 Localizers filtered: {filtered_stats['localizers']}") |
|
|
print(f" ⏭️ Too few slices (≤2) filtered: {filtered_stats['too_few_slices']}") |
|
|
print(f" 📊 Unique subjects: {len(subjects_found)}") |
|
|
print(f" ⚡ Speed: {len(all_dirs)/scan_time:.0f} dirs/second") |
|
|
print(f"{'='*80}\n") |
|
|
|
|
|
|
|
|
if cache_file: |
|
|
save_directory_cache(dicom_dirs, cache_file) |
|
|
|
|
|
return dicom_dirs |
|
|
|
|
|
def prepare_scan_metadata(scan_dir): |
|
|
""" |
|
|
Prepare metadata for a scan without processing. |
|
|
|
|
|
Args: |
|
|
scan_dir: Directory containing DICOM files for one scan |
|
|
|
|
|
Returns: |
|
|
tuple: (dicom_file_paths, num_files, subject_id, scan_id) |
|
|
""" |
|
|
|
|
|
dicom_files = [f for f in os.listdir(scan_dir) |
|
|
if f.endswith('.dcm') and os.path.isfile(os.path.join(scan_dir, f))] |
|
|
num_dicom_files = len(dicom_files) |
|
|
|
|
|
if num_dicom_files == 0: |
|
|
raise ValueError("No valid DICOM files found") |
|
|
|
|
|
|
|
|
dicom_file_paths = [os.path.join(scan_dir, f) for f in dicom_files] |
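    # Derive scan_id and subject_id from the path, assuming an NLST-style layout of
    # .../NLST/<subject_id>/<study date>/<series>/; fall back to positional parsing otherwise.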
|
|
|
|
|
|
|
|
|
|
|
path_parts = scan_dir.rstrip('/').split('/') |
|
|
scan_id = path_parts[-1] if path_parts[-1] else path_parts[-2] |
|
|
|
|
|
|
|
|
try: |
|
|
nlst_idx = path_parts.index('NLST') |
|
|
subject_id = path_parts[nlst_idx + 1] |
|
|
except (ValueError, IndexError): |
|
|
|
|
|
subject_id = path_parts[-3] if len(path_parts) >= 3 else path_parts[-1] |
|
|
|
|
|
return dicom_file_paths, num_dicom_files, subject_id, scan_id |
|
|
|
|
|
def save_checkpoint(all_embeddings, all_metadata, failed, output_dir, checkpoint_num): |
|
|
""" |
|
|
Save a checkpoint of embeddings and metadata. |
|
|
|
|
|
Args: |
|
|
all_embeddings: List of embedding arrays |
|
|
all_metadata: List of metadata dictionaries |
|
|
failed: List of failed scans |
|
|
output_dir: Output directory |
|
|
checkpoint_num: Checkpoint number |
|
|
""" |
|
|
print(f"\n💾 Saving checkpoint {checkpoint_num}...") |
|
|
|
|
|
|
|
|
embeddings_array = np.array(all_embeddings) |
|
|
embedding_dim = int(embeddings_array.shape[1]) if len(embeddings_array.shape) > 1 else int(embeddings_array.shape[0]) |
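    # One row per scan: metadata columns plus a single array-valued 'embedding' column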
|
|
|
|
|
|
|
|
df_data = { |
|
|
'case_number': [m['case_number'] for m in all_metadata], |
|
|
'subject_id': [m['subject_id'] for m in all_metadata], |
|
|
'scan_id': [m['scan_id'] for m in all_metadata], |
|
|
'timepoint': [m.get('timepoint') for m in all_metadata], |
|
|
'dicom_directory': [m['dicom_directory'] for m in all_metadata], |
|
|
'num_dicom_files': [m['num_dicom_files'] for m in all_metadata], |
|
|
'embedding_index': [m['embedding_index'] for m in all_metadata], |
|
|
'embedding': list(embeddings_array) |
|
|
} |
|
|
df = pd.DataFrame(df_data) |
|
|
|
|
|
|
|
|
checkpoint_path = os.path.join(output_dir, f"checkpoint_{checkpoint_num}_embeddings.parquet") |
|
|
df.to_parquet(checkpoint_path, index=False, compression='snappy') |
|
|
print(f" ✓ Saved embeddings checkpoint: {checkpoint_path}") |
|
|
|
|
|
|
|
|
    checkpoint_metadata = {
        "checkpoint_num": checkpoint_num,
        "timestamp": datetime.now().isoformat(),
        "total_scans": len(all_embeddings),
        "num_failed_scans": len(failed),
        "embedding_shape": list(embeddings_array.shape),
        "scans": all_metadata,
        "failed_scans": failed
    }
|
|
metadata_path = os.path.join(output_dir, f"checkpoint_{checkpoint_num}_metadata.json") |
|
|
with open(metadata_path, 'w') as f: |
|
|
json.dump(checkpoint_metadata, f, indent=2) |
|
|
print(f" ✓ Saved metadata checkpoint: {metadata_path}") |
|
|
print(f"💾 Checkpoint {checkpoint_num} complete!\n") |
|
|
|
|
|
def process_scan(model, device, scan_dir): |
|
|
""" |
|
|
Process a single scan directory and extract embeddings. |
|
|
|
|
|
Args: |
|
|
model: Pre-loaded SybilHFWrapper model |
|
|
device: Device to run on (e.g., cuda:0, cuda:1) |
|
|
scan_dir: Directory containing DICOM files for one scan |
|
|
|
|
|
Returns: |
|
|
tuple: (embeddings, scan_metadata) |
|
|
""" |
|
|
dicom_file_paths, num_dicom_files, subject_id, scan_id = prepare_scan_metadata(scan_dir) |
|
|
|
|
|
print(f"\nProcessing: {scan_dir}") |
|
|
print(f"DICOM files: {num_dicom_files}") |
|
|
|
|
|
|
|
|
embeddings = extract_embeddings(model, dicom_file_paths, device) |
|
|
|
|
|
print(f"Embedding shape: {embeddings.shape}") |
|
|
|
|
|
|
|
|
timepoint = extract_timepoint_from_path(scan_dir) |
|
|
if timepoint: |
|
|
print(f"Timepoint: {timepoint}") |
|
|
else: |
|
|
print(f"Timepoint: Not detected") |
|
|
|
|
|
|
|
|
scan_metadata = { |
|
|
"case_number": subject_id, |
|
|
"subject_id": subject_id, |
|
|
"scan_id": scan_id, |
|
|
"timepoint": timepoint, |
|
|
"dicom_directory": scan_dir, |
|
|
"num_dicom_files": num_dicom_files, |
|
|
"embedding_index": None, |
|
|
"statistics": { |
|
|
"mean": float(np.mean(embeddings)), |
|
|
"std": float(np.std(embeddings)), |
|
|
"min": float(np.min(embeddings)), |
|
|
"max": float(np.max(embeddings)) |
|
|
} |
|
|
} |
|
|
|
|
|
return embeddings, scan_metadata |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
import argparse |
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='Extract Sybil embeddings from DICOM scans') |
|
|
|
|
|
|
|
|
parser.add_argument('--root-dir', type=str, required=True, |
|
|
help='Root directory containing DICOM files (e.g., /path/to/NLST)') |
|
|
parser.add_argument('--pid-csv', type=str, default=None, |
|
|
help='CSV file with "pid" column to filter subjects (e.g., subsets/hybridModels-train.csv)') |
|
|
parser.add_argument('--output-dir', type=str, default='embeddings_output', |
|
|
help='Output directory for embeddings (default: embeddings_output)') |
|
|
parser.add_argument('--max-subjects', type=int, default=None, |
|
|
help='Maximum number of subjects to process (for testing)') |
|
|
|
|
|
|
|
|
parser.add_argument('--num-gpus', type=int, default=1, |
|
|
help='Number of GPUs to use (default: 1)') |
|
|
parser.add_argument('--num-parallel', type=int, default=1, |
|
|
help='Number of parallel scans to process simultaneously (default: 1, recommended: 1-4 depending on GPU memory)') |
|
|
parser.add_argument('--num-workers', type=int, default=4, |
|
|
help='Number of parallel workers for directory scanning (default: 4, recommended: 4-12 depending on storage speed)') |
|
|
parser.add_argument('--checkpoint-interval', type=int, default=1000, |
|
|
help='Save checkpoint every N scans (default: 1000)') |
|
|
|
|
|
args = parser.parse_args() |
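    # Example invocation (script name and paths are illustrative only):
    #   python extract_sybil_embeddings.py --root-dir /path/to/NLST \
    #       --pid-csv subsets/hybridModels-train.csv --output-dir embeddings_output \
    #       --num-gpus 2 --num-parallel 2 --num-workers 8 --checkpoint-interval 1000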
|
|
|
|
|
|
|
|
|
|
|
|
|
|
root_dir = args.root_dir |
|
|
output_dir = args.output_dir |
|
|
max_subjects = args.max_subjects |
|
|
num_gpus = args.num_gpus |
|
|
num_parallel_scans = args.num_parallel |
|
|
num_scan_workers = args.num_workers |
|
|
checkpoint_interval = args.checkpoint_interval |
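    # Prefer an existing full-run directory cache so repeat runs can skip the slow filesystem scan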
|
|
|
|
|
|
|
|
main_cache = "embeddings_output_full/directory_cache.json" |
|
|
if os.path.exists(main_cache): |
|
|
cache_file = main_cache |
|
|
print(f"✓ Found main directory cache: {main_cache}") |
|
|
else: |
|
|
cache_file = os.path.join(output_dir, "directory_cache.json") |
|
|
|
|
|
|
|
|
if not os.path.exists(root_dir): |
|
|
raise ValueError(f"Root directory does not exist: {root_dir}") |
|
|
|
|
|
|
|
|
filter_pids = None |
|
|
if args.pid_csv: |
|
|
print(f"Loading subject PIDs from: {args.pid_csv}") |
|
|
|
|
csv_data = pd.read_csv(args.pid_csv) |
|
|
filter_pids = set(str(pid) for pid in csv_data['pid'].unique()) |
|
|
print(f" Found {len(filter_pids)} unique PIDs to extract") |
|
|
print(f" Examples: {list(filter_pids)[:5]}") |
|
|
|
|
|
|
|
|
os.makedirs(output_dir, exist_ok=True) |
|
|
|
|
|
|
|
|
print(f"\n{'='*80}") |
|
|
print(f"CONFIGURATION") |
|
|
print(f"{'='*80}") |
|
|
print(f"Root directory: {root_dir}") |
|
|
print(f"Output directory: {output_dir}") |
|
|
print(f"Number of GPUs: {num_gpus}") |
|
|
print(f"Parallel scans: {num_parallel_scans} (recommended: 1-4 depending on GPU memory)") |
|
|
print(f"Directory scan workers: {num_scan_workers} (recommended: 4-12 depending on storage)") |
|
|
print(f"Checkpoint interval: {checkpoint_interval} scans") |
|
|
if filter_pids: |
|
|
print(f"Filtering to: {len(filter_pids)} PIDs from CSV") |
|
|
if max_subjects: |
|
|
print(f"Max subjects: {max_subjects}") |
|
|
print(f"{'='*80}\n") |
|
|
|
|
|
|
|
|
if num_parallel_scans > 1: |
|
|
        estimated_vram = max(1, num_parallel_scans // num_gpus) * 10  # rough per-GPU estimate; avoid 0 when num_gpus > num_parallel_scans
|
|
print(f"⚠️ MEMORY WARNING:") |
|
|
print(f" Parallel processing requires ~{estimated_vram}GB VRAM per GPU") |
|
|
print(f" If you encounter OOM errors, reduce --num-parallel to 1-2") |
|
|
print(f" Current: {num_parallel_scans} scans across {num_gpus} GPU(s)\n") |
|
|
|
|
|
|
|
|
|
|
|
dicom_dirs = find_dicom_directories(root_dir, max_subjects=max_subjects, |
|
|
num_workers=num_scan_workers, cache_file=cache_file, |
|
|
filter_pids=filter_pids) |
|
|
|
|
|
if len(dicom_dirs) == 0: |
|
|
raise ValueError(f"No directories with DICOM files found in {root_dir}") |
|
|
|
|
|
print(f"\n{'='*80}") |
|
|
print(f"Found {len(dicom_dirs)} directories containing DICOM files") |
|
|
print(f"{'='*80}\n") |
|
|
|
|
|
|
|
|
print(f"🎮 Detected {num_gpus} GPU(s)") |
|
|
print(f"🚀 Will process {num_parallel_scans} scans in parallel ({num_parallel_scans // num_gpus} per GPU)") |
|
|
print(f"💾 Checkpoints will be saved every {checkpoint_interval} scans\n") |
|
|
|
|
|
|
|
|
models_and_devices = [] |
|
|
for gpu_id in range(num_gpus): |
|
|
model, device = load_model(gpu_id) |
|
|
models_and_devices.append((model, device, gpu_id)) |
|
|
|
|
|
|
|
|
all_embeddings = [] |
|
|
all_metadata = [] |
|
|
failed = [] |
|
|
checkpoint_counter = 0 |
|
|
|
|
|
if num_parallel_scans > 1: |
|
|
|
|
|
print(f"Processing {num_parallel_scans} scans in parallel across {num_gpus} GPU(s)...") |
|
|
print(f"Note: This requires ~{(num_parallel_scans // num_gpus) * 10}GB VRAM per GPU.\n") |
|
|
|
|
|
from functools import partial |
|
|
from concurrent.futures import as_completed |
|
|
|
|
|
|
|
|
batch_size = checkpoint_interval |
|
|
num_batches = (len(dicom_dirs) + batch_size - 1) // batch_size |
|
|
|
|
|
for batch_idx in range(num_batches): |
|
|
start_idx = batch_idx * batch_size |
|
|
end_idx = min(start_idx + batch_size, len(dicom_dirs)) |
|
|
batch_dirs = dicom_dirs[start_idx:end_idx] |
|
|
|
|
|
print(f"\n{'='*80}") |
|
|
print(f"Processing batch {batch_idx + 1}/{num_batches} (scans {start_idx + 1} to {end_idx})") |
|
|
print(f"{'='*80}\n") |
|
|
|
|
|
|
|
|
|
|
|
with ThreadPoolExecutor(max_workers=num_parallel_scans) as executor: |
|
|
|
|
|
|
|
|
future_to_info = {} |
|
|
scan_queue = list(enumerate(batch_dirs)) |
|
|
scans_submitted = 0 |
|
|
|
|
|
|
|
|
while scan_queue and scans_submitted < num_parallel_scans: |
|
|
i, scan_dir = scan_queue.pop(0) |
|
|
|
|
|
gpu_idx = i % num_gpus |
|
|
model, device, gpu_id = models_and_devices[gpu_idx] |
|
|
|
|
|
|
|
|
process_func = partial(process_scan, model, device) |
|
|
future = executor.submit(process_func, scan_dir) |
|
|
future_to_info[future] = (start_idx + i + 1, scan_dir, gpu_id) |
|
|
scans_submitted += 1 |
|
|
|
|
|
|
|
|
while future_to_info: |
|
|
|
|
|
done_futures = [] |
|
|
for future in list(future_to_info.keys()): |
|
|
if future.done(): |
|
|
done_futures.append(future) |
|
|
|
|
|
if not done_futures: |
|
|
import time |
|
|
time.sleep(0.1) |
|
|
continue |
|
|
|
|
|
|
|
|
for future in done_futures: |
|
|
scan_num, scan_dir, gpu_id = future_to_info.pop(future) |
|
|
try: |
|
|
print(f"[{scan_num}/{len(dicom_dirs)}] Processing on GPU {gpu_id}...") |
|
|
embeddings, scan_metadata = future.result() |
|
|
|
|
|
|
|
|
scan_metadata["embedding_index"] = len(all_embeddings) |
|
|
|
|
|
|
|
|
all_embeddings.append(embeddings) |
|
|
all_metadata.append(scan_metadata) |
|
|
|
|
|
except Exception as e: |
|
|
print(f"ERROR processing {scan_dir}: {e}") |
|
|
failed.append({"scan_dir": scan_dir, "error": str(e)}) |
|
|
|
|
|
|
|
|
if scan_queue: |
|
|
i, next_scan_dir = scan_queue.pop(0) |
|
|
gpu_idx = i % num_gpus |
|
|
model, device, gpu_id = models_and_devices[gpu_idx] |
|
|
|
|
|
process_func = partial(process_scan, model, device) |
|
|
new_future = executor.submit(process_func, next_scan_dir) |
|
|
future_to_info[new_future] = (start_idx + i + 1, next_scan_dir, gpu_id) |
|
|
|
|
|
|
|
|
checkpoint_counter += 1 |
|
|
save_checkpoint(all_embeddings, all_metadata, failed, output_dir, checkpoint_counter) |
|
|
|
|
|
print(f"Progress: {len(all_embeddings)}/{len(dicom_dirs)} scans completed " |
|
|
f"({len(all_embeddings)/len(dicom_dirs)*100:.1f}%)\n") |
|
|
else: |
|
|
|
|
|
model, device, gpu_id = models_and_devices[0] |
|
|
|
|
|
for i, scan_dir in enumerate(dicom_dirs, 1): |
|
|
try: |
|
|
print(f"\n[{i}/{len(dicom_dirs)}] Processing scan...") |
|
|
|
|
|
|
|
|
embeddings, scan_metadata = process_scan(model, device, scan_dir) |
|
|
|
|
|
|
|
|
scan_metadata["embedding_index"] = len(all_embeddings) |
|
|
|
|
|
|
|
|
all_embeddings.append(embeddings) |
|
|
all_metadata.append(scan_metadata) |
|
|
|
|
|
|
|
|
if i % checkpoint_interval == 0: |
|
|
checkpoint_counter += 1 |
|
|
save_checkpoint(all_embeddings, all_metadata, failed, output_dir, checkpoint_counter) |
|
|
|
|
|
except Exception as e: |
|
|
print(f"ERROR processing {scan_dir}: {e}") |
|
|
failed.append({"scan_dir": scan_dir, "error": str(e)}) |
|
|
|
|
|
|
|
|
|
|
|
embeddings_array = np.array(all_embeddings) |
|
|
embedding_dim = int(embeddings_array.shape[1]) if len(embeddings_array.shape) > 1 else int(embeddings_array.shape[0]) |
|
|
|
|
|
|
|
|
|
|
|
df_data = { |
|
|
'case_number': [m['case_number'] for m in all_metadata], |
|
|
'subject_id': [m['subject_id'] for m in all_metadata], |
|
|
'scan_id': [m['scan_id'] for m in all_metadata], |
|
|
'timepoint': [m.get('timepoint') for m in all_metadata], |
|
|
'dicom_directory': [m['dicom_directory'] for m in all_metadata], |
|
|
'num_dicom_files': [m['num_dicom_files'] for m in all_metadata], |
|
|
'embedding_index': [m['embedding_index'] for m in all_metadata], |
|
|
'embedding': list(embeddings_array) |
|
|
} |
|
|
|
|
|
|
|
|
df = pd.DataFrame(df_data) |
|
|
|
|
|
|
|
|
embeddings_filename = "all_embeddings.parquet" |
|
|
embeddings_path = os.path.join(output_dir, embeddings_filename) |
|
|
df.to_parquet(embeddings_path, index=False, compression='snappy') |
|
|
print(f"\n✅ Saved FINAL embeddings to Parquet: {embeddings_path}") |
|
|
|
|
|
|
|
|
dataset_metadata = { |
|
|
"dataset_info": { |
|
|
"root_directory": root_dir, |
|
|
"total_scans": len(all_embeddings), |
|
|
"failed_scans": len(failed), |
|
|
"embedding_shape": list(embeddings_array.shape), |
|
|
"embedding_dim": embedding_dim, |
|
|
"extraction_timestamp": datetime.now().isoformat(), |
|
|
"file_format": "parquet" |
|
|
}, |
|
|
"model_info": { |
|
|
"model": "Lab-Rasool/sybil", |
|
|
"layer": "after_relu_before_dropout", |
|
|
"ensemble_averaged": True, |
|
|
"num_ensemble_models": 5 |
|
|
}, |
|
|
"embeddings_file": embeddings_filename, |
|
|
"parquet_schema": { |
|
|
"metadata_columns": ["case_number", "subject_id", "scan_id", "timepoint", "dicom_directory", "num_dicom_files", "embedding_index"], |
|
|
"embedding_column": "embedding", |
|
|
"embedding_shape": f"({embedding_dim},)", |
|
|
"total_columns": 8, |
|
|
"timepoint_info": "T0=1999, T1=2000, T2=2001, etc. Extracted from year in path. Can be None if not detected." |
|
|
}, |
|
|
"filtering_info": { |
|
|
"localizer_detection": "Scans identified as localizers (by folder name or DICOM metadata) are filtered out", |
|
|
"min_slices": "Scans with ≤2 DICOM files are filtered out (likely localizers)", |
|
|
"accepted_scans": len(all_embeddings) |
|
|
}, |
|
|
"scans": all_metadata, |
|
|
"failed_scans": failed |
|
|
} |
|
|
|
|
|
metadata_filename = "dataset_metadata.json" |
|
|
metadata_path = os.path.join(output_dir, metadata_filename) |
|
|
with open(metadata_path, 'w') as f: |
|
|
json.dump(dataset_metadata, f, indent=2) |
|
|
print(f"✅ Saved FINAL metadata: {metadata_path}") |
|
|
|
|
|
|
|
|
print(f"\n{'='*80}") |
|
|
print(f"PROCESSING COMPLETE") |
|
|
print(f"{'='*80}") |
|
|
print(f"Successfully processed: {len(all_embeddings)}/{len(dicom_dirs)} scans") |
|
|
print(f"Failed: {len(failed)}/{len(dicom_dirs)} scans") |
|
|
print(f"\nEmbeddings array shape: {embeddings_array.shape}") |
|
|
print(f"Saved embeddings to: {embeddings_path}") |
|
|
print(f"Saved metadata to: {metadata_path}") |
|
|
|
|
|
|
|
|
timepoint_counts = {} |
|
|
for m in all_metadata: |
|
|
tp = m.get('timepoint', 'Unknown') |
|
|
timepoint_counts[tp] = timepoint_counts.get(tp, 0) + 1 |
|
|
|
|
|
if timepoint_counts: |
|
|
print(f"\n📅 Timepoint Distribution:") |
|
|
for tp in sorted(timepoint_counts.keys(), key=lambda x: (x is None, x)): |
|
|
count = timepoint_counts[tp] |
|
|
if tp is None: |
|
|
print(f" Unknown/Not detected: {count} scans") |
|
|
else: |
|
|
print(f" {tp}: {count} scans") |
|
|
|
|
|
if failed: |
|
|
print(f"\nFailed scans: {len(failed)}") |
|
|
for fail_info in failed[:5]: |
|
|
print(f" - {fail_info['scan_dir']}") |
|
|
print(f" Error: {fail_info['error']}") |
|
|
if len(failed) > 5: |
|
|
print(f" ... and {len(failed) - 5} more failures") |
|
|
|
|
|
print(f"\n{'='*80}") |
|
|
print(f"For downstream training, load embeddings with:") |
|
|
print(f" import pandas as pd") |
|
|
print(f" import numpy as np") |
|
|
print(f" df = pd.read_parquet('{embeddings_path}')") |
|
|
print(f" # Total rows: {len(df)}, Total columns: {len(df.columns)}") |
|
|
print(f" # Extract embeddings array: embeddings = np.stack(df['embedding'].values)") |
|
|
print(f" # Shape: {embeddings_array.shape}") |
|
|
print(f" # Access individual: df.loc[0, 'embedding'] -> array of shape ({embedding_dim},)") |
|
|
print(f"{'='*80}") |
|
|
|