devjas1 committed on
Commit
8dd961f
·
1 Parent(s): 07fb119

FEAT(modern_ml_architecture): implement comprehensive transformer-based architecture for polymer analysis with multi-task learning and uncertainty estimation


Adds transformer-based architecture for polymer analysis

Implements a modern machine-learning architecture for polymer analysis built around transformer models, with multi-task learning and uncertainty estimation.

Enhancements include structured prediction outputs, ensemble model integration, and an improved training framework to support classification, regression, and uncertainty quantification.

Relates to enhancing model robustness and accuracy in polymer material predictions.

Files changed (1)
  1. modules/modern_ml_architecture.py +957 -0
modules/modern_ml_architecture.py ADDED
@@ -0,0 +1,957 @@
"""
Modern ML Architecture for POLYMEROS
Implements transformer-based models, multi-task learning, and ensemble methods
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional, Union, Any
from dataclasses import dataclass
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import accuracy_score, mean_squared_error
import xgboost as xgb
from scipy import stats
import warnings
import json
from pathlib import Path


@dataclass
class ModelPrediction:
    """Structured prediction output with uncertainty quantification"""

    prediction: Union[int, float, np.ndarray]
    confidence: float
    uncertainty_epistemic: float  # Model uncertainty
    uncertainty_aleatoric: float  # Data uncertainty
    class_probabilities: Optional[np.ndarray] = None
    feature_importance: Optional[Dict[str, float]] = None
    explanation: Optional[str] = None


@dataclass
class MultiTaskTarget:
    """Multi-task learning targets"""

    classification_target: Optional[int] = None  # Polymer type classification
    degradation_level: Optional[float] = None  # Continuous degradation score
    property_predictions: Optional[Dict[str, float]] = None  # Material properties
    aging_rate: Optional[float] = None  # Rate of aging prediction


class SpectralTransformerBlock(nn.Module):
    """Transformer block optimized for spectral data"""

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads

        # Multi-head attention
        self.attention = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )

        # Feed-forward network
        self.ff_network = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

        # Layer normalization
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Self-attention with residual connection
        attn_output, _ = self.attention(x, x, x, attn_mask=mask)
        x = self.ln1(x + self.dropout(attn_output))

        # Feed-forward with residual connection
        ff_output = self.ff_network(x)
        x = self.ln2(x + self.dropout(ff_output))

        return x


class SpectralPositionalEncoding(nn.Module):
    """Positional encoding adapted for spectral wavenumber information"""

    def __init__(self, d_model: int, max_seq_length: int = 2000):
        super().__init__()
        self.d_model = d_model

        # Create positional encoding matrix
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)

        # Use different frequencies for different dimensions
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :].to(x.device)


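# Note: the encoding above is the standard sinusoidal scheme,
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# applied to the index of each spectral point rather than to a token position.

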
class SpectralTransformer(nn.Module):
    """Transformer architecture optimized for spectral analysis"""

    def __init__(
        self,
        input_dim: int = 1,
        d_model: int = 256,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 1024,
        max_seq_length: int = 2000,
        num_classes: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.d_model = d_model
        self.num_classes = num_classes

        # Input projection
        self.input_projection = nn.Linear(input_dim, d_model)

        # Positional encoding
        self.pos_encoding = SpectralPositionalEncoding(d_model, max_seq_length)

        # Transformer layers
        self.transformer_layers = nn.ModuleList(
            [
                SpectralTransformerBlock(d_model, num_heads, d_ff, dropout)
                for _ in range(num_layers)
            ]
        )

        # Classification head
        self.classification_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, num_classes),
        )

        # Regression heads for multi-task learning
        self.degradation_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, 1),
        )

        self.property_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, 5),  # Predict 5 material properties
        )

        # Uncertainty estimation layers
        self.uncertainty_head = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.ReLU(),
            nn.Linear(d_model // 4, 2),  # Epistemic and aleatoric uncertainty
        )

        # Attention pooling for sequence aggregation
        self.attention_pool = nn.MultiheadAttention(d_model, 1, batch_first=True)
        self.pool_query = nn.Parameter(torch.randn(1, 1, d_model))

        self.dropout = nn.Dropout(dropout)

    def forward(
        self, x: torch.Tensor, return_attention: bool = False
    ) -> Dict[str, torch.Tensor]:
        batch_size = x.size(0)

        # Input projection and positional encoding
        x = self.input_projection(x)  # (batch, seq_len, d_model)
        x = self.pos_encoding(x)
        x = self.dropout(x)

        # Collected attention weights (currently only the pooling attention)
        attention_weights = []

        # Pass through transformer layers
        for layer in self.transformer_layers:
            x = layer(x)

        # Attention pooling to get sequence representation
        query = self.pool_query.expand(batch_size, -1, -1)
        pooled_output, pool_attention = self.attention_pool(query, x, x)
        pooled_output = pooled_output.squeeze(1)  # (batch, d_model)

        if return_attention:
            attention_weights.append(pool_attention)

        # Multi-task outputs
        outputs = {}

        # Classification output
        classification_logits = self.classification_head(pooled_output)
        outputs["classification_logits"] = classification_logits
        outputs["classification_probs"] = F.softmax(classification_logits, dim=-1)

        # Degradation prediction
        degradation_pred = self.degradation_head(pooled_output)
        outputs["degradation_prediction"] = degradation_pred

        # Property predictions
        property_pred = self.property_head(pooled_output)
        outputs["property_predictions"] = property_pred

        # Uncertainty estimation (softplus keeps both estimates positive)
        uncertainty_pred = self.uncertainty_head(pooled_output)
        outputs["uncertainty_epistemic"] = F.softplus(uncertainty_pred[:, 0])
        outputs["uncertainty_aleatoric"] = F.softplus(uncertainty_pred[:, 1])

        if return_attention:
            outputs["attention_weights"] = attention_weights

        return outputs


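# Shape check (illustrative): a batch of 4 spectra, 500 points, 1 channel:
#   model = SpectralTransformer(input_dim=1, max_seq_length=500)
#   out = model(torch.randn(4, 500, 1))
#   out["classification_probs"]   -> (4, 2)
#   out["degradation_prediction"] -> (4, 1)
#   out["property_predictions"]   -> (4, 5)

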
class BayesianUncertaintyEstimator:
    """Bayesian uncertainty quantification using Monte Carlo dropout"""

    def __init__(self, model: nn.Module, num_samples: int = 100):
        self.model = model
        self.num_samples = num_samples

    def enable_dropout(self, model: nn.Module):
        """Enable dropout for uncertainty estimation"""
        for module in model.modules():
            if isinstance(module, nn.Dropout):
                module.train()

    def predict_with_uncertainty(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Predict with uncertainty quantification using Monte Carlo dropout

        Args:
            x: Input tensor

        Returns:
            Predictions with uncertainty estimates
        """
        self.model.eval()
        self.enable_dropout(self.model)

        classification_probs = []
        degradation_preds = []
        uncertainty_estimates = []

        with torch.no_grad():
            for _ in range(self.num_samples):
                output = self.model(x)
                classification_probs.append(output["classification_probs"])
                degradation_preds.append(output["degradation_prediction"])
                uncertainty_estimates.append(
                    torch.stack(
                        [
                            output["uncertainty_epistemic"],
                            output["uncertainty_aleatoric"],
                        ],
                        dim=1,
                    )
                )

        # Stack predictions
        classification_stack = torch.stack(
            classification_probs, dim=0
        )  # (num_samples, batch, classes)
        degradation_stack = torch.stack(degradation_preds, dim=0)
        uncertainty_stack = torch.stack(uncertainty_estimates, dim=0)

        # Calculate statistics
        mean_classification = classification_stack.mean(dim=0)
        std_classification = classification_stack.std(dim=0)

        mean_degradation = degradation_stack.mean(dim=0)
        std_degradation = degradation_stack.std(dim=0)

        mean_uncertainty = uncertainty_stack.mean(dim=0)

        # Epistemic uncertainty: spread of class probabilities across MC samples
        epistemic_uncertainty = std_classification.mean(dim=1)

        # Aleatoric uncertainty: the model's own data-noise estimate, averaged
        aleatoric_uncertainty = mean_uncertainty[:, 1]

        return {
            "mean_classification": mean_classification,
            "std_classification": std_classification,
            "mean_degradation": mean_degradation,
            "std_degradation": std_degradation,
            "epistemic_uncertainty": epistemic_uncertainty,
            "aleatoric_uncertainty": aleatoric_uncertainty,
            "total_uncertainty": epistemic_uncertainty + aleatoric_uncertainty,
        }


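# Usage sketch (illustrative): wrap a trained model and average stochastic
# forward passes; `x` here stands for a (batch, seq_len, 1) tensor:
#   estimator = BayesianUncertaintyEstimator(model, num_samples=50)
#   mc = estimator.predict_with_uncertainty(x)
#   mc["epistemic_uncertainty"]  # disagreement across MC-dropout passes
#   mc["aleatoric_uncertainty"]  # model's own data-noise estimate

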
class EnsembleModel:
    """Ensemble model combining multiple approaches"""

    def __init__(self):
        self.models = {}
        self.weights = {}
        self.is_fitted = False

    def add_transformer_model(self, model: SpectralTransformer, weight: float = 1.0):
        """Add transformer model to ensemble"""
        self.models["transformer"] = model
        self.weights["transformer"] = weight

    def add_random_forest(self, n_estimators: int = 100, weight: float = 1.0):
        """Add Random Forest to ensemble"""
        self.models["random_forest_clf"] = RandomForestClassifier(
            n_estimators=n_estimators, random_state=42, oob_score=True
        )
        self.models["random_forest_reg"] = RandomForestRegressor(
            n_estimators=n_estimators, random_state=42, oob_score=True
        )
        self.weights["random_forest"] = weight

    def add_xgboost(self, weight: float = 1.0):
        """Add XGBoost to ensemble"""
        self.models["xgboost_clf"] = xgb.XGBClassifier(
            n_estimators=100, random_state=42, eval_metric="logloss"
        )
        self.models["xgboost_reg"] = xgb.XGBRegressor(n_estimators=100, random_state=42)
        self.weights["xgboost"] = weight

    def fit(
        self,
        X: np.ndarray,
        y_classification: np.ndarray,
        y_degradation: Optional[np.ndarray] = None,
    ):
        """
        Fit ensemble models

        Args:
            X: Input features (flattened spectra for traditional ML models)
            y_classification: Classification targets
            y_degradation: Degradation targets (optional)
        """
        # Fit Random Forest
        if "random_forest_clf" in self.models:
            self.models["random_forest_clf"].fit(X, y_classification)
            if y_degradation is not None:
                self.models["random_forest_reg"].fit(X, y_degradation)

        # Fit XGBoost
        if "xgboost_clf" in self.models:
            self.models["xgboost_clf"].fit(X, y_classification)
            if y_degradation is not None:
                self.models["xgboost_reg"].fit(X, y_degradation)

        self.is_fitted = True

    def predict(
        self, X: np.ndarray, X_transformer: Optional[torch.Tensor] = None
    ) -> ModelPrediction:
        """
        Ensemble prediction with uncertainty quantification

        Args:
            X: Input features for traditional ML models
            X_transformer: Input tensor for transformer model

        Returns:
            Ensemble prediction with uncertainty
        """
        if not self.is_fitted and "transformer" not in self.models:
            raise ValueError(
                "Ensemble must be fitted or contain pre-trained transformer"
            )

        classification_probs = []
        degradation_preds = []
        model_weights = []

        # Random Forest predictions
        if (
            "random_forest_clf" in self.models
            and self.models["random_forest_clf"] is not None
        ):
            rf_probs = self.models["random_forest_clf"].predict_proba(X)
            classification_probs.append(rf_probs)
            model_weights.append(self.weights["random_forest"])

            if "random_forest_reg" in self.models:
                rf_degradation = self.models["random_forest_reg"].predict(X)
                degradation_preds.append(rf_degradation)

        # XGBoost predictions
        if "xgboost_clf" in self.models and self.models["xgboost_clf"] is not None:
            xgb_probs = self.models["xgboost_clf"].predict_proba(X)
            classification_probs.append(xgb_probs)
            model_weights.append(self.weights["xgboost"])

            if "xgboost_reg" in self.models:
                xgb_degradation = self.models["xgboost_reg"].predict(X)
                degradation_preds.append(xgb_degradation)

        # Transformer predictions
        if "transformer" in self.models and X_transformer is not None:
            self.models["transformer"].eval()
            with torch.no_grad():
                transformer_output = self.models["transformer"](X_transformer)
            transformer_probs = (
                transformer_output["classification_probs"].cpu().numpy()
            )
            classification_probs.append(transformer_probs)
            model_weights.append(self.weights["transformer"])

            transformer_degradation = (
                transformer_output["degradation_prediction"].cpu().numpy()
            )
            degradation_preds.append(transformer_degradation.flatten())

        # Weighted ensemble
        if classification_probs:
            model_weights = np.array(model_weights)
            model_weights = model_weights / np.sum(model_weights)  # Normalize

            # Weighted average of probabilities
            ensemble_probs = np.zeros_like(classification_probs[0])
            for i, probs in enumerate(classification_probs):
                ensemble_probs += model_weights[i] * probs

            # Predicted class
            predicted_class = int(np.argmax(ensemble_probs, axis=1)[0])
            confidence = float(np.max(ensemble_probs, axis=1)[0])

            # Calculate uncertainty from model disagreement
            prob_variance = np.var([probs[0] for probs in classification_probs], axis=0)
            epistemic_uncertainty = float(np.mean(prob_variance))

            # Aleatoric uncertainty (simple estimate from residual confidence)
            aleatoric_uncertainty = 1.0 - confidence

            # Degradation prediction (assumes every classifier above also
            # contributed a regressor, so the weights align with the list)
            ensemble_degradation = None
            if degradation_preds:
                ensemble_degradation = np.average(
                    degradation_preds, weights=model_weights, axis=0
                )[0]

        else:
            raise ValueError("No valid predictions could be made")

        # Feature importance (from Random Forest if available)
        feature_importance = None
        if (
            "random_forest_clf" in self.models
            and self.models["random_forest_clf"] is not None
        ):
            importance = self.models["random_forest_clf"].feature_importances_
            # Convert to wavenumber-based importance (assuming spectral input)
            feature_importance = {
                f"wavenumber_{i}": float(importance[i]) for i in range(len(importance))
            }

        return ModelPrediction(
            prediction=predicted_class,
            confidence=confidence,
            uncertainty_epistemic=epistemic_uncertainty,
            uncertainty_aleatoric=aleatoric_uncertainty,
            class_probabilities=ensemble_probs[0],
            feature_importance=feature_importance,
            explanation=self._generate_explanation(
                predicted_class, confidence, ensemble_degradation
            ),
        )

    def _generate_explanation(
        self,
        predicted_class: int,
        confidence: float,
        degradation: Optional[float] = None,
    ) -> str:
        """Generate human-readable explanation"""
        class_names = {0: "Stable (Unweathered)", 1: "Weathered"}
        class_name = class_names.get(predicted_class, f"Class {predicted_class}")

        explanation = f"Predicted class: {class_name} (confidence: {confidence:.3f})"

        if degradation is not None:
            explanation += f"\nEstimated degradation level: {degradation:.3f}"

        if confidence > 0.8:
            explanation += "\nHigh confidence prediction - strong spectral evidence"
        elif confidence > 0.6:
            explanation += "\nModerate confidence - some uncertainty in classification"
        else:
            explanation += "\nLow confidence - significant uncertainty, consider additional analysis"

        return explanation


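# Usage sketch (illustrative): assemble and fit the traditional members;
# `X` stands for an (n_samples, n_points) matrix of flattened spectra:
#   ensemble = EnsembleModel()
#   ensemble.add_random_forest(weight=0.5)
#   ensemble.add_xgboost(weight=0.5)
#   ensemble.fit(X, y_classification, y_degradation)
#   pred = ensemble.predict(X[:1])  # returns a ModelPrediction

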
class MultiTaskLearningFramework:
    """Framework for multi-task learning in polymer analysis"""

    def __init__(self, model: SpectralTransformer):
        self.model = model
        self.task_weights = {
            "classification": 1.0,
            "degradation": 0.5,
            "properties": 0.3,
        }
        self.optimizer = None
        self.scheduler = None

    def setup_training(self, learning_rate: float = 1e-4):
        """Setup optimizer and scheduler"""
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=learning_rate, weight_decay=0.01
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=100
        )

    def compute_loss(
        self,
        outputs: Dict[str, torch.Tensor],
        targets: MultiTaskTarget,
        batch_size: int,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute multi-task loss

        Args:
            outputs: Model outputs
            targets: Multi-task targets
            batch_size: Batch size

        Returns:
            Loss components
        """
        losses = {}
        total_loss = 0
        device = outputs["classification_logits"].device

        # Classification loss
        if targets.classification_target is not None:
            classification_loss = F.cross_entropy(
                outputs["classification_logits"],
                torch.tensor(
                    [targets.classification_target] * batch_size,
                    dtype=torch.long,
                    device=device,
                ),
            )
            losses["classification"] = classification_loss
            total_loss += self.task_weights["classification"] * classification_loss

        # Degradation regression loss
        if targets.degradation_level is not None:
            degradation_loss = F.mse_loss(
                outputs["degradation_prediction"].squeeze(),
                torch.tensor(
                    [targets.degradation_level] * batch_size,
                    dtype=torch.float,
                    device=device,
                ),
            )
            losses["degradation"] = degradation_loss
            total_loss += self.task_weights["degradation"] * degradation_loss

        # Property prediction loss
        if targets.property_predictions is not None:
            property_targets = torch.tensor(
                [[targets.property_predictions.get(f"prop_{i}", 0.0) for i in range(5)]]
                * batch_size,
                dtype=torch.float,
                device=device,
            )
            property_loss = F.mse_loss(
                outputs["property_predictions"], property_targets
            )
            losses["properties"] = property_loss
            total_loss += self.task_weights["properties"] * property_loss

        # Uncertainty regularization
        uncertainty_reg = torch.mean(outputs["uncertainty_epistemic"]) + torch.mean(
            outputs["uncertainty_aleatoric"]
        )
        losses["uncertainty_reg"] = uncertainty_reg
        total_loss += 0.01 * uncertainty_reg  # Small weight for regularization

        losses["total"] = total_loss
        return losses

    def train_step(self, x: torch.Tensor, targets: MultiTaskTarget) -> Dict[str, float]:
        """Single training step"""
        self.model.train()
        if self.optimizer is None:
            raise ValueError(
                "Optimizer is not initialized. Call setup_training() to initialize it."
            )
        self.optimizer.zero_grad()

        outputs = self.model(x)
        losses = self.compute_loss(outputs, targets, x.size(0))

        losses["total"].backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        return {
            k: float(v.item()) if torch.is_tensor(v) else float(v)
            for k, v in losses.items()
        }


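# Usage sketch (illustrative): one optimization step on a batch `x` of shape
# (batch, seq_len, 1) sharing one target (see create_multitask_targets below):
#   framework = MultiTaskLearningFramework(model)
#   framework.setup_training(learning_rate=1e-4)
#   target = create_multitask_targets(1, degradation_score=0.7)
#   loss_dict = framework.train_step(x, target)  # {"total": ..., ...}

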
class ModernMLPipeline:
    """Complete modern ML pipeline for polymer analysis"""

    def __init__(self, config: Optional[Dict] = None):
        self.config = config or self._default_config()
        self.transformer_model = None
        self.ensemble_model = None
        self.uncertainty_estimator = None
        self.multi_task_framework = None

    def _default_config(self) -> Dict:
        """Default configuration"""
        return {
            "transformer": {
                "d_model": 256,
                "num_heads": 8,
                "num_layers": 6,
                "d_ff": 1024,
                "dropout": 0.1,
                "num_classes": 2,
            },
            "ensemble": {
                "transformer_weight": 0.4,
                "random_forest_weight": 0.3,
                "xgboost_weight": 0.3,
            },
            "uncertainty": {"num_mc_samples": 50},
            "training": {"learning_rate": 1e-4, "batch_size": 32, "num_epochs": 100},
        }

    def initialize_models(self, input_dim: int = 1, max_seq_length: int = 2000):
        """Initialize all models"""
        # Transformer model
        self.transformer_model = SpectralTransformer(
            input_dim=input_dim,
            d_model=self.config["transformer"]["d_model"],
            num_heads=self.config["transformer"]["num_heads"],
            num_layers=self.config["transformer"]["num_layers"],
            d_ff=self.config["transformer"]["d_ff"],
            max_seq_length=max_seq_length,
            num_classes=self.config["transformer"]["num_classes"],
            dropout=self.config["transformer"]["dropout"],
        )

        # Uncertainty estimator
        self.uncertainty_estimator = BayesianUncertaintyEstimator(
            self.transformer_model,
            num_samples=self.config["uncertainty"]["num_mc_samples"],
        )

        # Multi-task framework
        self.multi_task_framework = MultiTaskLearningFramework(self.transformer_model)

        # Ensemble model
        self.ensemble_model = EnsembleModel()
        self.ensemble_model.add_transformer_model(
            self.transformer_model, self.config["ensemble"]["transformer_weight"]
        )
        self.ensemble_model.add_random_forest(
            weight=self.config["ensemble"]["random_forest_weight"]
        )
        self.ensemble_model.add_xgboost(
            weight=self.config["ensemble"]["xgboost_weight"]
        )

    def train_ensemble(
        self,
        X_flat: np.ndarray,
        X_transformer: torch.Tensor,
        y_classification: np.ndarray,
        y_degradation: Optional[np.ndarray] = None,
    ):
        """Train the ensemble model (X_transformer is reserved for a full
        transformer training loop)"""
        if self.ensemble_model is None:
            raise ValueError("Models not initialized. Call initialize_models() first.")

        # Train traditional ML models
        self.ensemble_model.fit(X_flat, y_classification, y_degradation)

        # Setup transformer training
        if self.multi_task_framework is None:
            raise ValueError(
                "Multi-task framework is not initialized. Call initialize_models() first."
            )
        self.multi_task_framework.setup_training(
            self.config["training"]["learning_rate"]
        )

        print(
            "Ensemble training completed (transformer training would require a full training loop)"
        )

    def predict_with_all_methods(
        self, X_flat: np.ndarray, X_transformer: torch.Tensor
    ) -> Dict[str, Any]:
        """
        Comprehensive prediction using all methods

        Args:
            X_flat: Flattened spectral data for traditional ML
            X_transformer: Tensor format for transformer

        Returns:
            Complete prediction results
        """
        if self.ensemble_model is None:
            raise ValueError(
                "Ensemble model is not initialized. Call initialize_models() first."
            )

        results = {}

        # Ensemble prediction
        ensemble_pred = self.ensemble_model.predict(X_flat, X_transformer)
        results["ensemble"] = ensemble_pred

        # Transformer with uncertainty
        if self.transformer_model is not None:
            if self.uncertainty_estimator is None:
                raise ValueError(
                    "Uncertainty estimator is not initialized. Call initialize_models() first."
                )
            uncertainty_pred = self.uncertainty_estimator.predict_with_uncertainty(
                X_transformer
            )
            results["transformer_uncertainty"] = uncertainty_pred

        # Individual model predictions for comparison
        individual_predictions = {}

        if "random_forest_clf" in self.ensemble_model.models:
            rf_pred = self.ensemble_model.models["random_forest_clf"].predict_proba(
                X_flat
            )[0]
            individual_predictions["random_forest"] = rf_pred

        if "xgboost_clf" in self.ensemble_model.models:
            xgb_pred = self.ensemble_model.models["xgboost_clf"].predict_proba(X_flat)[0]
            individual_predictions["xgboost"] = xgb_pred

        results["individual_models"] = individual_predictions

        return results

    def get_model_insights(
        self, X_flat: np.ndarray, X_transformer: torch.Tensor
    ) -> Dict[str, Any]:
        """
        Generate insights about model behavior and predictions

        Args:
            X_flat: Flattened spectral data
            X_transformer: Transformer input format

        Returns:
            Model insights and explanations
        """
        insights = {}

        # Feature importance from Random Forest
        if (
            self.ensemble_model is not None
            and "random_forest_clf" in self.ensemble_model.models
            and self.ensemble_model.models["random_forest_clf"] is not None
        ):
            rf_importance = self.ensemble_model.models[
                "random_forest_clf"
            ].feature_importances_
            top_features = np.argsort(rf_importance)[-10:][::-1]
            insights["top_spectral_regions"] = {
                f"wavenumber_{idx}": float(rf_importance[idx]) for idx in top_features
            }

        # Attention weights from transformer
        if self.transformer_model is not None:
            self.transformer_model.eval()
            with torch.no_grad():
                outputs = self.transformer_model(X_transformer, return_attention=True)
                if "attention_weights" in outputs:
                    insights["attention_patterns"] = outputs["attention_weights"]

        # Uncertainty analysis
        predictions = self.predict_with_all_methods(X_flat, X_transformer)
        if "transformer_uncertainty" in predictions:
            uncertainty_data = predictions["transformer_uncertainty"]
            insights["uncertainty_analysis"] = {
                "epistemic_uncertainty": float(
                    uncertainty_data["epistemic_uncertainty"].mean()
                ),
                "aleatoric_uncertainty": float(
                    uncertainty_data["aleatoric_uncertainty"].mean()
                ),
                "total_uncertainty": float(
                    uncertainty_data["total_uncertainty"].mean()
                ),
                "confidence_level": (
                    "high"
                    if uncertainty_data["total_uncertainty"].mean() < 0.1
                    else (
                        "medium"
                        if uncertainty_data["total_uncertainty"].mean() < 0.3
                        else "low"
                    )
                ),
            }

        # Model agreement analysis
        if "individual_models" in predictions:
            individual = predictions["individual_models"]
            agreements = []
            for model1_name, model1_pred in individual.items():
                for model2_name, model2_pred in individual.items():
                    if model1_name != model2_name:
                        # Calculate agreement based on prediction similarity
                        agreement = 1.0 - np.abs(model1_pred - model2_pred).mean()
                        agreements.append(agreement)

            insights["model_agreement"] = {
                "average_agreement": float(np.mean(agreements)) if agreements else 0.0,
                "agreement_level": (
                    "high"
                    if agreements and np.mean(agreements) > 0.8
                    else "medium" if agreements and np.mean(agreements) > 0.6 else "low"
                ),
            }

        return insights

    def save_models(self, save_path: Path):
        """Save trained models"""
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        # Save transformer model
        if self.transformer_model is not None:
            torch.save(
                self.transformer_model.state_dict(), save_path / "transformer_model.pth"
            )

        # Save configuration
        with open(save_path / "config.json", "w") as f:
            json.dump(self.config, f, indent=2)

        print(f"Models saved to {save_path}")

    def load_models(self, load_path: Path):
        """Load pre-trained models"""
        load_path = Path(load_path)

        # Load configuration
        with open(load_path / "config.json", "r") as f:
            self.config = json.load(f)

        # Initialize and load transformer
        self.initialize_models()
        if (
            self.transformer_model is not None
            and (load_path / "transformer_model.pth").exists()
        ):
            self.transformer_model.load_state_dict(
                torch.load(load_path / "transformer_model.pth", map_location="cpu")
            )
        else:
            raise ValueError(
                "Transformer model is not initialized or model file is missing."
            )

        print(f"Models loaded from {load_path}")


# Utility functions for data preparation
def prepare_transformer_input(
    spectral_data: np.ndarray, max_length: int = 2000
) -> torch.Tensor:
    """
    Prepare spectral data for transformer input

    Args:
        spectral_data: Raw spectral intensities (1D array)
        max_length: Maximum sequence length

    Returns:
        Formatted tensor for transformer
    """
    # Ensure proper length
    if len(spectral_data) > max_length:
        # Downsample
        indices = np.linspace(0, len(spectral_data) - 1, max_length, dtype=int)
        spectral_data = spectral_data[indices]
    elif len(spectral_data) < max_length:
        # Pad with zeros
        padding = np.zeros(max_length - len(spectral_data))
        spectral_data = np.concatenate([spectral_data, padding])

    # Reshape for transformer: (batch_size, sequence_length, features)
    return torch.tensor(spectral_data, dtype=torch.float32).unsqueeze(0).unsqueeze(-1)


def create_multitask_targets(
    classification_label: int,
    degradation_score: Optional[float] = None,
    material_properties: Optional[Dict[str, float]] = None,
) -> MultiTaskTarget:
    """
    Create multi-task learning targets

    Args:
        classification_label: Classification target (0 or 1)
        degradation_score: Continuous degradation score [0, 1]
        material_properties: Dictionary of material properties

    Returns:
        MultiTaskTarget object
    """
    return MultiTaskTarget(
        classification_target=classification_label,
        degradation_level=degradation_score,
        property_predictions=material_properties,
    )
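
Usage sketch: a minimal end-to-end run of the module above on synthetic stand-in data. The names (X_flat, y_cls, y_deg) and sizes (20 spectra of 500 points) are illustrative assumptions, and the transformer itself would still need a full training loop.

import numpy as np

# Synthetic stand-in data (illustrative only)
rng = np.random.default_rng(0)
X_flat = rng.random((20, 500))
y_cls = rng.integers(0, 2, size=20)
y_deg = rng.random(20)

pipeline = ModernMLPipeline()
pipeline.initialize_models(input_dim=1, max_seq_length=500)

# Traditional models fit on the flat matrix; the transformer input is a
# (1, seq_len, 1) tensor built from a single spectrum
X_t = prepare_transformer_input(X_flat[0], max_length=500)
pipeline.train_ensemble(X_flat, X_t, y_cls, y_deg)

# Predict on one spectrum with all methods and print the explanation
result = pipeline.predict_with_all_methods(X_flat[:1], X_t)
print(result["ensemble"].explanation)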