Aakash-Tripathi committed
Commit a0f3988 · verified · 1 Parent(s): d78f088

Delete modeling_sybil_hf.py

Files changed (1)
  1. modeling_sybil_hf.py +0 -298
modeling_sybil_hf.py DELETED
@@ -1,298 +0,0 @@
-"""
-Self-contained Hugging Face wrapper for Sybil lung cancer risk prediction model.
-This version works directly from HF without requiring external Sybil package.
-"""
-
-import os
-import json
-import sys
-import torch
-import numpy as np
-from typing import List, Dict, Optional
-from dataclasses import dataclass
-from transformers.modeling_outputs import BaseModelOutput
-from safetensors.torch import load_file
-
-# Add model path to sys.path for imports
-current_dir = os.path.dirname(os.path.abspath(__file__))
-if current_dir not in sys.path:
-    sys.path.insert(0, current_dir)
-
-try:
-    from .configuration_sybil import SybilConfig
-    from .modeling_sybil import SybilForRiskPrediction
-    from .image_processing_sybil import SybilImageProcessor
-except ImportError:
-    from configuration_sybil import SybilConfig
-    from modeling_sybil import SybilForRiskPrediction
-    from image_processing_sybil import SybilImageProcessor
-
-
-@dataclass
-class SybilOutput(BaseModelOutput):
-    """
-    Output class for Sybil model predictions.
-
-    Args:
-        risk_scores: Risk scores for each year (1-6 years by default)
-        attentions: Optional attention maps if requested
-    """
-    risk_scores: torch.FloatTensor = None
-    attentions: Optional[Dict] = None
-
-
-class SybilHFWrapper:
-    """
-    Hugging Face wrapper for Sybil ensemble model.
-    Provides a simple interface for lung cancer risk prediction from CT scans.
-    """
-
-    def __init__(self, config: SybilConfig = None):
-        """
-        Initialize the Sybil model ensemble.
-
-        Args:
-            config: Model configuration (will use default if not provided)
-        """
-        self.config = config if config is not None else SybilConfig()
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        # Get the directory where this file is located
-        self.model_dir = os.path.dirname(os.path.abspath(__file__))
-
-        # Initialize image processor
-        self.image_processor = SybilImageProcessor()
-
-        # Load calibrator
-        self.calibrator = self._load_calibrator()
-
-        # Load ensemble models
-        self.models = self._load_ensemble_models()
-
-    def _load_calibrator(self) -> Dict:
-        """Load ensemble calibrator data"""
-        calibrator_path = os.path.join(self.model_dir, "checkpoints", "sybil_ensemble_simple_calibrator.json")
-
-        if os.path.exists(calibrator_path):
-            with open(calibrator_path, 'r') as f:
-                return json.load(f)
-        else:
-            # Try alternative location
-            calibrator_path = os.path.join(self.model_dir, "calibrator_data.json")
-            if os.path.exists(calibrator_path):
-                with open(calibrator_path, 'r') as f:
-                    return json.load(f)
-        return {}
-
-    def _load_ensemble_models(self) -> List[torch.nn.Module]:
-        """Load all models in the ensemble from safetensors files"""
-        models = []
-
-        # Load each model in the ensemble (Sybil uses 5 models)
-        for i in range(1, 6):
-            model_subdir = os.path.join(self.model_dir, f"sybil_{i}")
-            weights_path = os.path.join(model_subdir, "model.safetensors")
-
-            if os.path.exists(weights_path):
-                # Create model instance
-                model = SybilForRiskPrediction(self.config)
-
-                # Load weights from safetensors
-                try:
-                    state_dict = load_file(weights_path)
-                    model.load_state_dict(state_dict, strict=False)
-                except Exception as e:
-                    print(f"Warning: Could not load weights for sybil_{i}: {e}")
-                    continue
-
-                model.to(self.device)
-                model.eval()
-                models.append(model)
-            else:
-                # Try loading from checkpoints directory
-                checkpoint_path = os.path.join(self.model_dir, "checkpoints", f"sybil_{i}.ckpt")
-                if os.path.exists(checkpoint_path):
-                    model = SybilForRiskPrediction(self.config)
-                    checkpoint = torch.load(checkpoint_path, map_location='cpu')
-
-                    # Extract state dict
-                    if 'state_dict' in checkpoint:
-                        state_dict = checkpoint['state_dict']
-                    else:
-                        state_dict = checkpoint
-
-                    # Remove 'model.' prefix if present
-                    cleaned_state_dict = {}
-                    for k, v in state_dict.items():
-                        if k.startswith('model.'):
-                            cleaned_state_dict[k[6:]] = v
-                        else:
-                            cleaned_state_dict[k] = v
-
-                    model.load_state_dict(cleaned_state_dict, strict=False)
-                    model.to(self.device)
-                    model.eval()
-                    models.append(model)
-
-        if not models:
-            raise ValueError("No models could be loaded from the ensemble. Please ensure model files are present.")
-
-        print(f"Loaded {len(models)} models in ensemble")
-        return models
-
-    def _apply_calibration(self, scores: np.ndarray) -> np.ndarray:
-        """
-        Apply calibration to raw model outputs.
-
-        Args:
-            scores: Raw risk scores from the model
-
-        Returns:
-            Calibrated risk scores
-        """
-        if not self.calibrator:
-            return scores
-
-        calibrated = np.zeros_like(scores)
-
-        for year in range(scores.shape[1]):
-            year_key = f"Year{year + 1}"
-            if year_key in self.calibrator:
-                cal_data = self.calibrator[year_key]
-                if isinstance(cal_data, list) and len(cal_data) > 0:
-                    cal_data = cal_data[0]
-
-                # Apply linear calibration if available
-                if isinstance(cal_data, dict) and "coef" in cal_data and "intercept" in cal_data:
-                    coef = cal_data["coef"][0][0] if isinstance(cal_data["coef"], list) else cal_data["coef"]
-                    intercept = cal_data["intercept"][0] if isinstance(cal_data["intercept"], list) else cal_data["intercept"]
-
-                    # Apply calibration
-                    calibrated[:, year] = scores[:, year] * coef + intercept
-                    calibrated[:, year] = 1 / (1 + np.exp(-calibrated[:, year]))  # Sigmoid
-                else:
-                    calibrated[:, year] = scores[:, year]
-            else:
-                calibrated[:, year] = scores[:, year]
-
-        return calibrated
-
-    def preprocess_dicom(self, dicom_paths: List[str]) -> torch.Tensor:
-        """
-        Preprocess DICOM files for model input.
-
-        Args:
-            dicom_paths: List of paths to DICOM files
-
-        Returns:
-            Preprocessed tensor ready for model input
-        """
-        # Use the image processor to handle DICOM files
-        result = self.image_processor(dicom_paths, file_type="dicom", return_tensors="pt")
-        pixel_values = result["pixel_values"]
-
-        # Ensure we have 5D tensor (B, C, D, H, W)
-        if pixel_values.ndim == 4:
-            pixel_values = pixel_values.unsqueeze(0)  # Add batch dimension
-
-        return pixel_values.to(self.device)
-
-    def predict(self, dicom_paths: List[str], return_attentions: bool = False) -> SybilOutput:
-        """
-        Run prediction on a CT scan series.
-
-        Args:
-            dicom_paths: List of paths to DICOM files for a single CT series
-            return_attentions: Whether to return attention maps
-
-        Returns:
-            SybilOutput with risk scores and optional attention maps
-        """
-        # Preprocess the DICOM files
-        pixel_values = self.preprocess_dicom(dicom_paths)
-
-        # Run inference with ensemble
-        all_predictions = []
-        all_attentions = []
-
-        with torch.no_grad():
-            for model in self.models:
-                output = model(
-                    pixel_values=pixel_values,
-                    return_attentions=return_attentions
-                )
-
-                # Extract risk scores
-                if hasattr(output, 'risk_scores'):
-                    predictions = output.risk_scores
-                else:
-                    predictions = output[0] if isinstance(output, tuple) else output
-
-                all_predictions.append(predictions.cpu().numpy())
-
-                if return_attentions and hasattr(output, 'image_attention'):
-                    all_attentions.append(output.image_attention)
-
-        # Average ensemble predictions
-        ensemble_pred = np.mean(all_predictions, axis=0)
-
-        # Apply calibration
-        calibrated_pred = self._apply_calibration(ensemble_pred)
-
-        # Convert back to torch tensor
-        risk_scores = torch.from_numpy(calibrated_pred).float()
-
-        # Average attentions if requested
-        attentions = None
-        if return_attentions and all_attentions:
-            attentions = {"image_attention": torch.stack(all_attentions).mean(dim=0)}
-
-        return SybilOutput(risk_scores=risk_scores, attentions=attentions)
-
-    def __call__(self, dicom_paths: List[str] = None, dicom_series: List[List[str]] = None, **kwargs) -> SybilOutput:
-        """
-        Convenience method for prediction.
-
-        Args:
-            dicom_paths: List of DICOM file paths for a single series
-            dicom_series: List of lists of DICOM paths for batch processing
-            **kwargs: Additional arguments passed to predict()
-
-        Returns:
-            SybilOutput with predictions
-        """
-        if dicom_series is not None:
-            # Batch processing
-            all_outputs = []
-            for paths in dicom_series:
-                output = self.predict(paths, **kwargs)
-                all_outputs.append(output.risk_scores)
-
-            risk_scores = torch.stack(all_outputs)
-            return SybilOutput(risk_scores=risk_scores)
-        elif dicom_paths is not None:
-            return self.predict(dicom_paths, **kwargs)
-        else:
-            raise ValueError("Either dicom_paths or dicom_series must be provided")
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
-        """
-        Load model from Hugging Face hub or local path.
-
-        Args:
-            pretrained_model_name_or_path: HF model ID or local path
-            **kwargs: Additional configuration arguments
-
-        Returns:
-            SybilHFWrapper instance
-        """
-        # Load configuration
-        config = kwargs.pop("config", None)
-        if config is None:
-            try:
-                config = SybilConfig.from_pretrained(pretrained_model_name_or_path)
-            except:
-                config = SybilConfig()
-
-        return cls(config=config)
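For reference, a minimal usage sketch of the wrapper this commit removes, based only on the API defined above; the checkpoint path and DICOM file names are placeholders, and the ensemble weights (sybil_1/ through sybil_5/model.safetensors) plus the calibrator JSON are assumed to sit next to the module.

# Minimal sketch (hypothetical paths) of driving the deleted SybilHFWrapper.
from modeling_sybil_hf import SybilHFWrapper

wrapper = SybilHFWrapper.from_pretrained("path/to/sybil-checkpoints")  # placeholder local path or HF model ID
dicom_paths = ["series/slice_001.dcm", "series/slice_002.dcm"]         # placeholder CT series
output = wrapper(dicom_paths=dicom_paths, return_attentions=False)
print(output.risk_scores)  # calibrated per-year risk scores (1-6 years by default)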