"""
Self-contained Hugging Face wrapper for the Sybil lung cancer risk prediction model.

This version works directly from Hugging Face (HF) without requiring the
external Sybil package.
"""
|
|
|
import os |
|
import json |
|
import sys |
|
import torch |
|
import numpy as np |
|
from typing import List, Dict, Optional |
|
from dataclasses import dataclass |
|
from transformers.modeling_outputs import BaseModelOutput |
|
from safetensors.torch import load_file |
|
|
|
|
|
# Put this file's directory on sys.path (once) so that the sibling Sybil
# modules can also be imported as plain top-level modules.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Prefer package-relative imports; if they raise ImportError, fall back to
# top-level imports resolved through the sys.path entry added above.
try:
    from .configuration_sybil import SybilConfig
    from .modeling_sybil import SybilForRiskPrediction
    from .image_processing_sybil import SybilImageProcessor
except ImportError:
    from configuration_sybil import SybilConfig
    from modeling_sybil import SybilForRiskPrediction
    from image_processing_sybil import SybilImageProcessor
|
|
|
|
|
@dataclass
class SybilOutput(BaseModelOutput):
    """
    Output container for Sybil model predictions.

    Args:
        risk_scores: Risk scores for each year (1-6 years by default).
            Defaults to ``None`` until a prediction fills it in.
        attentions: Optional attention maps; ``None`` unless they were requested.
    """

    # Both fields default to None, so both are annotated as Optional.
    risk_scores: Optional[torch.FloatTensor] = None
    attentions: Optional[Dict] = None
|
|
|
|
|
class SybilHFWrapper:
    """
    Hugging Face wrapper for the Sybil ensemble model.

    Provides a simple interface for lung cancer risk prediction from CT scans:
    every loaded ensemble member is run on the preprocessed scan, the member
    predictions are averaged, and the average is passed through the ensemble
    calibrator (when one was found on disk).
    """

    # Ensemble members are looked up on disk as sybil_1 .. sybil_<ENSEMBLE_SIZE>.
    ENSEMBLE_SIZE = 5

    def __init__(self, config: Optional["SybilConfig"] = None, model_dir: Optional[str] = None):
        """
        Initialize the Sybil model ensemble.

        Args:
            config: Model configuration (a default ``SybilConfig`` is used if
                not provided).
            model_dir: Directory that holds the weight and calibrator files.
                Defaults to the directory containing this source file.

        Raises:
            ValueError: If no ensemble member could be loaded.
        """
        self.config = config if config is not None else SybilConfig()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if model_dir is None:
            model_dir = os.path.dirname(os.path.abspath(__file__))
        self.model_dir = model_dir

        self.image_processor = SybilImageProcessor()

        # Both loaders read self.model_dir, so they run after it is set.
        self.calibrator = self._load_calibrator()
        self.models = self._load_ensemble_models()

    def _load_calibrator(self) -> Dict:
        """
        Load ensemble calibrator data.

        Returns:
            The parsed JSON content of the first calibrator file that exists,
            or an empty dict when none is present.
        """
        # Candidate locations, in priority order.
        candidates = (
            os.path.join(self.model_dir, "checkpoints", "sybil_ensemble_simple_calibrator.json"),
            os.path.join(self.model_dir, "calibrator_data.json"),
        )
        for calibrator_path in candidates:
            if os.path.exists(calibrator_path):
                with open(calibrator_path, "r", encoding="utf-8") as f:
                    return json.load(f)
        return {}

    @staticmethod
    def _read_checkpoint_state_dict(checkpoint_path: str) -> Dict:
        """Read a ``.ckpt`` file and return its state dict with any leading ``model.`` key prefix removed."""
        # NOTE(security): torch.load unpickles the file. Only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # The weights are either nested under 'state_dict' or are the
        # checkpoint object itself.
        state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

        prefix = 'model.'
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def _load_ensemble_models(self) -> List[torch.nn.Module]:
        """
        Load all models in the ensemble.

        For each member ``i`` the safetensors file ``sybil_<i>/model.safetensors``
        is preferred; if it does not exist, ``checkpoints/sybil_<i>.ckpt`` is
        used instead. Members with neither file, or whose weights fail to
        load, are skipped.

        Returns:
            The loaded models, moved to ``self.device`` and set to eval mode.

        Raises:
            ValueError: If no member could be loaded.
        """
        models = []

        for i in range(1, self.ENSEMBLE_SIZE + 1):
            weights_path = os.path.join(self.model_dir, f"sybil_{i}", "model.safetensors")
            checkpoint_path = os.path.join(self.model_dir, "checkpoints", f"sybil_{i}.ckpt")

            if os.path.exists(weights_path):
                read_state_dict, source_path = load_file, weights_path
            elif os.path.exists(checkpoint_path):
                read_state_dict, source_path = self._read_checkpoint_state_dict, checkpoint_path
            else:
                continue

            model = SybilForRiskPrediction(self.config)

            # Best effort for both formats: one unreadable member must not
            # prevent the remaining members from loading.
            try:
                state_dict = read_state_dict(source_path)
                load_result = model.load_state_dict(state_dict, strict=False)
            except Exception as e:
                print(f"Warning: Could not load weights for sybil_{i}: {e}")
                continue

            # strict=False tolerates key mismatches; report them so that a
            # partially initialized member does not go unnoticed.
            missing = getattr(load_result, "missing_keys", None) or []
            unexpected = getattr(load_result, "unexpected_keys", None) or []
            if missing or unexpected:
                print(
                    f"Warning: sybil_{i} weights loaded with {len(missing)} missing "
                    f"and {len(unexpected)} unexpected keys"
                )

            model.to(self.device)
            model.eval()
            models.append(model)

        if not models:
            raise ValueError("No models could be loaded from the ensemble. Please ensure model files are present.")

        print(f"Loaded {len(models)} models in ensemble")
        return models

    @staticmethod
    def _first_scalar(value) -> float:
        """Return the first element of a scalar or (nested) list as a Python float."""
        # Accepts x, [x] and [[x]] alike; the original nested form [[x]]
        # yields the same value as value[0][0].
        return float(np.asarray(value, dtype=float).ravel()[0])

    def _apply_calibration(self, scores: np.ndarray) -> np.ndarray:
        """
        Apply calibration to raw model outputs.

        For every column ``year`` with a usable ``"Year<year + 1>"`` calibrator
        entry (a dict with ``"coef"`` and ``"intercept"``, optionally wrapped
        in a list), the column becomes ``sigmoid(scores * coef + intercept)``.
        Columns without a usable entry are passed through unchanged.

        Args:
            scores: Raw risk scores from the model; indexed as
                ``scores[:, year]``, so it must be 2-D.

        Returns:
            Calibrated risk scores. When no calibrator is loaded, ``scores``
            itself is returned.
        """
        if not self.calibrator:
            return scores

        # Work on a floating-point copy so that integer inputs are not
        # truncated and uncalibrated columns keep their raw values.
        out_dtype = scores.dtype if np.issubdtype(scores.dtype, np.floating) else np.float64
        calibrated = scores.astype(out_dtype, copy=True)

        for year in range(scores.shape[1]):
            year_key = f"Year{year + 1}"
            if year_key not in self.calibrator:
                continue

            cal_data = self.calibrator[year_key]
            # An entry may be wrapped in a list; only its first item is used.
            if isinstance(cal_data, list) and len(cal_data) > 0:
                cal_data = cal_data[0]

            if not (isinstance(cal_data, dict) and "coef" in cal_data and "intercept" in cal_data):
                continue

            coef = self._first_scalar(cal_data["coef"])
            intercept = self._first_scalar(cal_data["intercept"])

            linear = scores[:, year] * coef + intercept
            calibrated[:, year] = 1 / (1 + np.exp(-linear))

        return calibrated

    def preprocess_dicom(self, dicom_paths: List[str]) -> torch.Tensor:
        """
        Preprocess DICOM files for model input.

        Args:
            dicom_paths: List of paths to DICOM files

        Returns:
            The processor's ``pixel_values`` on ``self.device``; a 4-D result
            gets a new leading dimension added.
        """
        result = self.image_processor(dicom_paths, file_type="dicom", return_tensors="pt")
        pixel_values = result["pixel_values"]

        if pixel_values.ndim == 4:
            pixel_values = pixel_values.unsqueeze(0)

        return pixel_values.to(self.device)

    def predict(self, dicom_paths: List[str], return_attentions: bool = False) -> "SybilOutput":
        """
        Run prediction on a CT scan series.

        Every ensemble member is evaluated under ``torch.no_grad()``; the
        member predictions are averaged and then calibrated.

        Args:
            dicom_paths: List of paths to DICOM files for a single CT series
            return_attentions: Whether to return attention maps

        Returns:
            SybilOutput with risk scores and, when requested and available,
            ``attentions["image_attention"]`` averaged over the members.
        """
        pixel_values = self.preprocess_dicom(dicom_paths)

        all_predictions = []
        all_attentions = []

        with torch.no_grad():
            for model in self.models:
                output = model(
                    pixel_values=pixel_values,
                    return_attentions=return_attentions,
                )

                if hasattr(output, 'risk_scores'):
                    predictions = output.risk_scores
                else:
                    predictions = output[0] if isinstance(output, tuple) else output

                all_predictions.append(predictions.cpu().numpy())

                if return_attentions:
                    # Skip members that expose the attribute but left it
                    # unset; torch.stack cannot handle None entries.
                    attention = getattr(output, 'image_attention', None)
                    if attention is not None:
                        all_attentions.append(attention)

        ensemble_pred = np.mean(all_predictions, axis=0)
        calibrated_pred = self._apply_calibration(ensemble_pred)
        risk_scores = torch.from_numpy(calibrated_pred).float()

        attentions = None
        if return_attentions and all_attentions:
            attentions = {"image_attention": torch.stack(all_attentions).mean(dim=0)}

        return SybilOutput(risk_scores=risk_scores, attentions=attentions)

    def __call__(self, dicom_paths: List[str] = None, dicom_series: List[List[str]] = None, **kwargs) -> "SybilOutput":
        """
        Convenience method for prediction.

        Args:
            dicom_paths: List of DICOM file paths for a single series
            dicom_series: List of lists of DICOM paths for batch processing.
                Takes precedence over ``dicom_paths`` when both are given.
            **kwargs: Additional arguments passed to predict()

        Returns:
            SybilOutput with predictions. In batch mode the per-series risk
            scores are stacked along a new leading dimension and attention
            maps are not included in the result.

        Raises:
            ValueError: If neither argument is provided, or ``dicom_series``
                is empty.
        """
        if dicom_series is not None:
            # torch.stack rejects an empty list with an opaque error; fail
            # early with a clear message instead.
            if not dicom_series:
                raise ValueError("dicom_series must contain at least one series")

            all_outputs = [self.predict(paths, **kwargs).risk_scores for paths in dicom_series]
            return SybilOutput(risk_scores=torch.stack(all_outputs))

        if dicom_paths is not None:
            return self.predict(dicom_paths, **kwargs)

        raise ValueError("Either dicom_paths or dicom_series must be provided")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load model from Hugging Face hub or local path.

        ``pretrained_model_name_or_path`` is used to load the configuration
        only; weight and calibrator files are read from the wrapper's model
        directory (see ``model_dir``).

        Args:
            pretrained_model_name_or_path: HF model ID or local path
            **kwargs: Optional ``config`` (a ready ``SybilConfig`` that skips
                config loading) and ``model_dir`` (forwarded to the
                constructor). Other keys are ignored.

        Returns:
            SybilHFWrapper instance
        """
        config = kwargs.pop("config", None)
        model_dir = kwargs.pop("model_dir", None)

        if config is None:
            try:
                config = SybilConfig.from_pretrained(pretrained_model_name_or_path)
            except Exception as e:
                # Keep the best-effort fallback, but make it visible and let
                # KeyboardInterrupt/SystemExit propagate.
                print(
                    f"Warning: Could not load config from {pretrained_model_name_or_path}: {e}. "
                    "Using default SybilConfig."
                )
                config = SybilConfig()

        # Forward model_dir only when given, so subclasses whose __init__
        # takes just `config` keep working.
        init_kwargs = {"config": config}
        if model_dir is not None:
            init_kwargs["model_dir"] = model_dir
        return cls(**init_kwargs)