Aakash-Tripathi committed on
Commit
91ebb5b
·
verified ·
1 Parent(s): a0f3988

Delete modeling_sybil.py

Browse files
Files changed (1) hide show
  1. modeling_sybil.py +0 -385
modeling_sybil.py DELETED
@@ -1,385 +0,0 @@
1
- """PyTorch Sybil model for lung cancer risk prediction"""
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torchvision
6
- from transformers import PreTrainedModel
7
- from transformers.modeling_outputs import BaseModelOutput
8
- from typing import Optional, Dict, List, Tuple
9
- import numpy as np
10
- from dataclasses import dataclass
11
-
12
- try:
13
- from .configuration_sybil import SybilConfig
14
- except ImportError:
15
- from configuration_sybil import SybilConfig
16
-
17
-
18
@dataclass
class SybilOutput(BaseModelOutput):
    """
    Container returned by the Sybil models in this file.

    Fields (all default to ``None``):
        risk_scores: per-year risk predictions, one column per follow-up
            year, i.e. shape ``(batch_size, max_followup)``.
        image_attention: within-slice attention map taken from the pooling
            layer's ``'image_attention_1'`` entry. The pooling layer in this
            file produces it with shape ``(batch_size, depth, height, width)``
            of the encoder feature map.
        volume_attention: attention map taken from the pooling layer's
            ``'volume_attention_1'`` entry. The pooling layer in this file
            normalises it jointly over depth, height and width, so it has
            shape ``(batch_size, depth, height, width)`` as well.
        hidden_states: pooled feature vector after ReLU and dropout, shape
            ``(batch_size, hidden_dim)``.

    The three optional fields are only populated when the caller asks for
    attentions.
    """

    # Field order is part of the interface (positional construction and
    # tuple conversion), so it is kept exactly as declared originally.
    risk_scores: Optional[torch.FloatTensor] = None
    image_attention: Optional[torch.FloatTensor] = None
    volume_attention: Optional[torch.FloatTensor] = None
    hidden_states: Optional[torch.FloatTensor] = None
37
-
38
-
39
class CumulativeProbabilityLayer(nn.Module):
    """Map a pooled feature vector to monotonically increasing risk scores.

    A single linear layer yields one value per follow-up year. Each value is
    squashed with a sigmoid, the per-year values are accumulated along the
    year axis, and the running total is divided by ``max_followup`` so the
    result stays within ``[0, 1]``.

    Args:
        hidden_dim: size of the incoming feature vector.
        max_followup: number of follow-up years (output columns).
    """

    def __init__(self, hidden_dim: int, max_followup: int = 6):
        super().__init__()
        self.max_followup = max_followup
        self.fc = nn.Linear(hidden_dim, max_followup)

    def forward(self, x):
        """Return cumulative risk scores of shape ``(..., max_followup)``."""
        # Each year contributes a non-negative increment in (0, 1) ...
        yearly_increment = self.fc(x).sigmoid()
        # ... so the running sum can never decrease from one year to the next.
        running_total = yearly_increment.cumsum(dim=-1)
        # At most `max_followup` increments of < 1 each: dividing bounds it by 1.
        return running_total / self.max_followup
53
-
54
-
55
class MultiAttentionPool(nn.Module):
    """Attention pooling over a 3D feature map ``[B, C, D, H, W]``.

    Two independent 1x1x1 convolutional heads each produce a single-channel
    score map:

    * the *volume* head is softmax-normalised jointly over every position
      (``D * H * W``) of a sample;
    * the *image* head is softmax-normalised separately inside each slice
      (over ``H * W`` for every depth index).

    The input features are multiplied by both weight maps and averaged over
    the three spatial axes.

    Args:
        channels: number of channels ``C`` of the incoming feature map.
    """

    def __init__(self, channels: int = 512):
        super().__init__()
        self.channels = channels

        # Volume-level attention head (normalised across the whole volume).
        self.volume_attention = nn.Sequential(
            nn.Conv3d(channels, 128, kernel_size=1),
            nn.ReLU(),
            nn.Conv3d(128, 1, kernel_size=1),
        )

        # Image-level attention head (normalised within each slice).
        self.image_attention = nn.Sequential(
            nn.Conv3d(channels, 128, kernel_size=1),
            nn.ReLU(),
            nn.Conv3d(128, 1, kernel_size=1),
        )

    def forward(self, x):
        """Pool ``x`` and return the pooled vector plus both attention maps.

        Args:
            x: feature map of shape ``[B, C, D, H, W]``.

        Returns:
            dict with
            ``'hidden'`` -- pooled features, shape ``[B, C]``;
            ``'volume_attention_1'`` -- volume weights, shape ``[B, D, H, W]``;
            ``'image_attention_1'`` -- per-slice weights, shape ``[B, D, H, W]``.
        """
        volume_att = self.volume_attention(x)  # [B, 1, D, H, W]
        image_att = self.image_attention(x)    # [B, 1, D, H, W]

        # Softmax over all D*H*W positions of each sample. flatten/reshape
        # (rather than view) also work for non-contiguous tensors.
        volume_att_weights = torch.softmax(
            volume_att.flatten(start_dim=1), dim=-1
        ).reshape(volume_att.shape)

        # Softmax over H*W inside every slice, done for all slices at once.
        # This replaces a Python loop that overwrote the raw scores in place,
        # one slice at a time.
        image_att_weights = torch.softmax(
            image_att.flatten(start_dim=3), dim=-1
        ).reshape(image_att.shape)

        # Weight the features by both maps, then global-average-pool.
        attended = x * volume_att_weights * image_att_weights
        hidden = attended.mean(dim=[2, 3, 4])

        return {
            'hidden': hidden,
            'volume_attention_1': volume_att_weights.squeeze(1),
            'image_attention_1': image_att_weights.squeeze(1),
        }
105
-
106
-
107
class SybilPreTrainedModel(PreTrainedModel):
    """
    Common base for Sybil models: binds the configuration class and supplies
    the weight-initialisation hook used by the pretrained-model machinery.
    """

    config_class = SybilConfig
    base_model_prefix = "sybil"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialise one sub-module.

        ``nn.Linear`` weights are drawn from a normal distribution with
        standard deviation ``config.initializer_range``; ``nn.Conv3d``
        weights use Kaiming-normal (fan-out, ReLU). Biases of both kinds are
        zeroed. Any other module type is left untouched.
        """
        # NOTE(review): no distinction is made between freshly created Conv3d
        # layers and those of a pretrained backbone -- confirm whether the
        # framework calls this hook on already-loaded encoder weights.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.Conv3d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        else:
            return

        # Shared tail for the two handled layer types.
        if module.bias is not None:
            module.bias.data.zero_()
126
-
127
-
128
class SybilForRiskPrediction(SybilPreTrainedModel):
    """
    Sybil model for lung cancer risk prediction from CT scans.

    This model takes 3D CT scan volumes as input and predicts cancer risk scores
    for multiple future time points (typically 1-6 years).

    Pipeline: R3D-18 encoder -> multi-attention pooling -> ReLU -> dropout ->
    cumulative probability layer -> optional calibration.
    """

    def __init__(self, config: SybilConfig):
        """Build the network from ``config``.

        Reads ``config.dropout``, ``config.hidden_dim``,
        ``config.max_followup`` and ``config.calibrator_data``.
        """
        super().__init__(config)
        self.config = config

        # R3D-18 backbone without its last two children.
        # NOTE(review): `pretrained=True` requests torchvision's pretrained
        # weights on every construction, including when a checkpoint is loaded
        # right afterwards -- confirm this is intended.
        encoder = torchvision.models.video.r3d_18(pretrained=True)
        self.image_encoder = nn.Sequential(*list(encoder.children())[:-2])

        # Multi-attention pooling.
        # NOTE(review): the pool is fixed at 512 channels while the risk head
        # below is sized by config.hidden_dim; the two must agree for forward()
        # to work -- verify against SybilConfig.
        self.pool = MultiAttentionPool(channels=512)

        # Classification layers
        self.relu = nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(p=config.dropout)

        # Risk prediction layer
        self.prob_of_failure_layer = CumulativeProbabilityLayer(
            config.hidden_dim,
            max_followup=config.max_followup,
        )

        # Calibrator for ensemble predictions
        self.calibrator = None
        if config.calibrator_data:
            self.set_calibrator(config.calibrator_data)

        # Initialize weights
        self.post_init()

    def set_calibrator(self, calibrator_data: Dict):
        """Store calibration data used by ``_calibrate_scores``."""
        self.calibrator = calibrator_data

    def _calibrate_scores(self, scores: torch.Tensor) -> torch.Tensor:
        """Apply per-year calibration to raw risk scores.

        Returns ``scores`` unchanged when no calibrator is set. Otherwise the
        scores make a round trip through NumPy on the CPU (so the result is
        detached from autograd) and come back on the original device.
        Calibrator entries are looked up under the keys ``"Year1"``,
        ``"Year2"``, ...; years without an entry pass through unmodified.
        """
        if self.calibrator is None:
            return scores

        scores_np = scores.detach().cpu().numpy()
        calibrated = np.zeros_like(scores_np)

        for year in range(scores_np.shape[1]):
            year_key = f"Year{year + 1}"
            if year_key in self.calibrator:
                calibrated[:, year] = self._apply_calibration(
                    scores_np[:, year],
                    self.calibrator[year_key],
                )
            else:
                calibrated[:, year] = scores_np[:, year]

        return torch.from_numpy(calibrated).to(scores.device)

    def _apply_calibration(self, scores: np.ndarray, calibrator_params: Dict) -> np.ndarray:
        """Apply one year's calibration transformation.

        Currently a placeholder that returns ``scores`` unchanged;
        ``calibrator_params`` is not used yet.
        """
        return scores

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        return_attentions: bool = False,
        return_dict: bool = True,
    ) -> SybilOutput:
        """
        Forward pass of the Sybil model.

        Args:
            pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, depth, height, width)`):
                Pixel values of CT scan volumes.
            return_attentions: (`bool`, *optional*, defaults to `False`):
                Whether to return attention weights (and, in the dict form,
                the pooled hidden state).
            return_dict: (`bool`, *optional*, defaults to `True`):
                Whether to return a `SybilOutput` instead of a plain tuple.

        Returns:
            `SybilOutput`, or a tuple `(risk_scores,)` /
            `(risk_scores, image_attention, volume_attention)` when
            `return_dict` is `False`.
        """
        # Extract features using 3D CNN backbone
        features = self.image_encoder(pixel_values)

        # Apply multi-attention pooling
        pool_output = self.pool(features)

        # Apply ReLU and dropout
        hidden = self.relu(pool_output['hidden'])
        hidden = self.dropout(hidden)

        # The cumulative probability layer already applies a sigmoid and
        # normalises its output to [0, 1]. Passing that through another
        # sigmoid (as was done before) compressed every score into roughly
        # [0.5, 0.73], so the layer output is used directly.
        risk_scores = self.prob_of_failure_layer(hidden)

        # Apply calibration if available
        risk_scores = self._calibrate_scores(risk_scores)

        image_attention = pool_output.get('image_attention_1') if return_attentions else None
        volume_attention = pool_output.get('volume_attention_1') if return_attentions else None

        if not return_dict:
            outputs = (risk_scores,)
            if return_attentions:
                outputs = outputs + (image_attention, volume_attention)
            return outputs

        return SybilOutput(
            risk_scores=risk_scores,
            image_attention=image_attention,
            volume_attention=volume_attention,
            hidden_states=hidden if return_attentions else None,
        )

    @classmethod
    def from_pretrained_ensemble(
        cls,
        pretrained_model_name_or_path,
        checkpoint_paths: List[str],
        calibrator_path: Optional[str] = None,
        **kwargs
    ):
        """
        Load an ensemble of Sybil models from checkpoints.

        Args:
            pretrained_model_name_or_path: Path to the pretrained model or model identifier;
                only used to load the config when none is passed via ``config=``.
            checkpoint_paths: List of paths to individual model checkpoints. Each
                checkpoint must contain a ``'state_dict'`` entry.
            calibrator_path: Path to a JSON file with calibration data.
            **kwargs: ``config`` is consumed here; other entries are ignored.

        Returns:
            SybilEnsemble: An ensemble of Sybil models.

        Raises:
            KeyError: if a checkpoint has no ``'state_dict'`` entry.
        """
        import json
        import warnings

        config = kwargs.pop("config", None)
        if config is None:
            config = SybilConfig.from_pretrained(pretrained_model_name_or_path)

        # Load calibrator if provided
        if calibrator_path:
            with open(calibrator_path, 'r', encoding='utf-8') as f:
                config.calibrator_data = json.load(f)

        prefix = 'model.'
        models = []
        for checkpoint_path in checkpoint_paths:
            model = cls(config)

            # SECURITY: torch.load unpickles the file and can execute
            # arbitrary code. Only load checkpoints from trusted sources
            # (or switch to weights_only loading where the installed torch
            # version and checkpoint format allow it).
            checkpoint = torch.load(checkpoint_path, map_location='cpu')

            # Remove 'model.' prefix from state dict keys if present
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in checkpoint['state_dict'].items()
            }

            # Map to new model structure
            mapped_state_dict = model._map_checkpoint_weights(state_dict)

            # strict=False tolerates key mismatches; report them instead of
            # dropping them silently so a badly matched checkpoint is visible.
            load_result = model.load_state_dict(mapped_state_dict, strict=False)
            if load_result.missing_keys or load_result.unexpected_keys:
                warnings.warn(
                    f"Checkpoint {checkpoint_path!r} did not fully match the model: "
                    f"{len(load_result.missing_keys)} missing key(s), "
                    f"{len(load_result.unexpected_keys)} unexpected key(s)."
                )
            models.append(model)

        return SybilEnsemble(models, config)

    def _map_checkpoint_weights(self, state_dict: Dict) -> Dict:
        """Keep only the checkpoint entries that belong to this model.

        Entries whose key starts with ``image_encoder``, ``pool`` or
        ``prob_of_failure_layer`` are passed through under the same key;
        everything else is dropped.
        """
        kept_prefixes = ('image_encoder', 'pool', 'prob_of_failure_layer')
        return {k: v for k, v in state_dict.items() if k.startswith(kept_prefixes)}
318
-
319
-
320
class SybilEnsemble:
    """Ensemble of Sybil models whose predictions are averaged.

    Args:
        models: the member models.
        config: configuration shared by the members (stored, not otherwise
            used by this class).
    """

    def __init__(self, models: List[SybilForRiskPrediction], config: SybilConfig):
        self.models = models
        self.config = config
        self.device = None  # set by `to()`

    def to(self, device):
        """Move all models to ``device`` and return ``self`` for chaining."""
        self.device = device
        for model in self.models:
            model.to(device)
        return self

    def eval(self):
        """Set all models to evaluation mode and return ``self``.

        Returning ``self`` matches ``to()`` above and lets callers write
        ``ensemble = ensemble.eval()`` without ending up with ``None``.
        """
        for model in self.models:
            model.eval()
        return self

    def __call__(
        self,
        pixel_values: torch.FloatTensor,
        return_attentions: bool = False,
    ) -> SybilOutput:
        """
        Run inference with every member and average the results.

        Args:
            pixel_values: Input CT scan volumes.
            return_attentions: Whether to return (averaged) attention maps.

        Returns:
            SybilOutput with averaged predictions from all models; the
            attention fields are ``None`` unless ``return_attentions`` is set.

        Raises:
            RuntimeError: if the ensemble contains no models.
        """
        if not self.models:
            # torch.stack on an empty list raises an opaque RuntimeError;
            # fail with the same exception type but a clear message.
            raise RuntimeError("SybilEnsemble has no models to run inference with")

        with torch.no_grad():
            outputs = [
                model(pixel_values=pixel_values, return_attentions=return_attentions)
                for model in self.models
            ]

        # Average predictions across ensemble members.
        risk_scores = torch.stack([out.risk_scores for out in outputs]).mean(dim=0)

        # Average attentions if requested
        image_attention = None
        volume_attention = None
        if return_attentions:
            image_attention = torch.stack(
                [out.image_attention for out in outputs]
            ).mean(dim=0)
            volume_attention = torch.stack(
                [out.volume_attention for out in outputs]
            ).mean(dim=0)

        return SybilOutput(
            risk_scores=risk_scores,
            image_attention=image_attention,
            volume_attention=volume_attention,
        )