devjas1 committed
Commit · 8dd961f
1 Parent(s): 07fb119
FEAT(modern_ml_architecture): implement comprehensive transformer-based architecture for polymer analysis with multi-task learning and uncertainty estimation
Adds transformer-based architecture for polymer analysis
Implements a modern machine learning architecture for polymer analysis built around transformer models, with multi-task learning and uncertainty estimation.
The module adds structured prediction outputs, ensemble model integration, and a training framework covering classification, regression, and uncertainty quantification.
Relates to improving model robustness and accuracy in polymer material predictions.
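
For context, a minimal usage sketch of the new pipeline (not part of the committed file). The synthetic spectra, labels, and the scaled-down config values below are illustrative assumptions; only the class and function names come from the added module.

    import numpy as np
    from modules.modern_ml_architecture import ModernMLPipeline, prepare_transformer_input

    # Small, fast config for illustration only; the module's defaults are larger.
    config = {
        "transformer": {"d_model": 64, "num_heads": 4, "num_layers": 2,
                        "d_ff": 128, "dropout": 0.1, "num_classes": 2},
        "ensemble": {"transformer_weight": 0.4, "random_forest_weight": 0.3,
                     "xgboost_weight": 0.3},
        "uncertainty": {"num_mc_samples": 5},
        "training": {"learning_rate": 1e-4, "batch_size": 8, "num_epochs": 5},
    }

    # Synthetic stand-ins for preprocessed spectra and labels.
    rng = np.random.default_rng(0)
    X_flat = rng.random((8, 500))                 # 8 spectra, 500 points each
    y_cls = np.array([0, 1, 0, 1, 0, 1, 0, 1])    # 0 = stable, 1 = weathered
    y_deg = rng.random(8)                          # degradation scores in [0, 1]
    X_torch = prepare_transformer_input(X_flat[0], max_length=500)  # (1, 500, 1)

    pipeline = ModernMLPipeline(config)
    pipeline.initialize_models(input_dim=1, max_seq_length=500)
    pipeline.train_ensemble(X_flat, X_torch, y_cls, y_deg)  # fits RF/XGBoost, sets up transformer training

    results = pipeline.predict_with_all_methods(X_flat[:1], X_torch)
    print(results["ensemble"].explanation)
    print(results["transformer_uncertainty"]["total_uncertainty"])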
modules/modern_ml_architecture.py
ADDED
@@ -0,0 +1,957 @@
"""
Modern ML Architecture for POLYMEROS
Implements transformer-based models, multi-task learning, and ensemble methods
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional, Union, Any
from dataclasses import dataclass
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import accuracy_score, mean_squared_error
import xgboost as xgb
from scipy import stats
import warnings
import json
from pathlib import Path


@dataclass
class ModelPrediction:
    """Structured prediction output with uncertainty quantification"""

    prediction: Union[int, float, np.ndarray]
    confidence: float
    uncertainty_epistemic: float  # Model uncertainty
    uncertainty_aleatoric: float  # Data uncertainty
    class_probabilities: Optional[np.ndarray] = None
    feature_importance: Optional[Dict[str, float]] = None
    explanation: Optional[str] = None


@dataclass
class MultiTaskTarget:
    """Multi-task learning targets"""

    classification_target: Optional[int] = None  # Polymer type classification
    degradation_level: Optional[float] = None  # Continuous degradation score
    property_predictions: Optional[Dict[str, float]] = None  # Material properties
    aging_rate: Optional[float] = None  # Rate of aging prediction


class SpectralTransformerBlock(nn.Module):
    """Transformer block optimized for spectral data"""

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads

        # Multi-head attention
        self.attention = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )

        # Feed-forward network
        self.ff_network = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

        # Layer normalization
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Self-attention with residual connection
        attn_output, attention_weights = self.attention(x, x, x, attn_mask=mask)
        x = self.ln1(x + self.dropout(attn_output))

        # Feed-forward with residual connection
        ff_output = self.ff_network(x)
        x = self.ln2(x + self.dropout(ff_output))

        return x


class SpectralPositionalEncoding(nn.Module):
    """Positional encoding adapted for spectral wavenumber information"""

    def __init__(self, d_model: int, max_seq_length: int = 2000):
        super().__init__()
        self.d_model = d_model

        # Create positional encoding matrix
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)

        # Use different frequencies for different dimensions
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :].to(x.device)


class SpectralTransformer(nn.Module):
    """Transformer architecture optimized for spectral analysis"""

    def __init__(
        self,
        input_dim: int = 1,
        d_model: int = 256,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 1024,
        max_seq_length: int = 2000,
        num_classes: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.d_model = d_model
        self.num_classes = num_classes

        # Input projection
        self.input_projection = nn.Linear(input_dim, d_model)

        # Positional encoding
        self.pos_encoding = SpectralPositionalEncoding(d_model, max_seq_length)

        # Transformer layers
        self.transformer_layers = nn.ModuleList(
            [
                SpectralTransformerBlock(d_model, num_heads, d_ff, dropout)
                for _ in range(num_layers)
            ]
        )

        # Classification head
        self.classification_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, num_classes),
        )

        # Regression heads for multi-task learning
        self.degradation_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, 1),
        )

        self.property_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, 5),  # Predict 5 material properties
        )

        # Uncertainty estimation layers
        self.uncertainty_head = nn.Sequential(
            nn.Linear(d_model, d_model // 4),
            nn.ReLU(),
            nn.Linear(d_model // 4, 2),  # Epistemic and aleatoric uncertainty
        )

        # Attention pooling for sequence aggregation
        self.attention_pool = nn.MultiheadAttention(d_model, 1, batch_first=True)
        self.pool_query = nn.Parameter(torch.randn(1, 1, d_model))

        self.dropout = nn.Dropout(dropout)

    def forward(
        self, x: torch.Tensor, return_attention: bool = False
    ) -> Dict[str, torch.Tensor]:
        batch_size, seq_len, input_dim = x.shape

        # Input projection and positional encoding
        x = self.input_projection(x)  # (batch, seq_len, d_model)
        x = self.pos_encoding(x)
        x = self.dropout(x)

        # Store attention weights if requested
        attention_weights = []

        # Pass through transformer layers
        for layer in self.transformer_layers:
            x = layer(x)

        # Attention pooling to get sequence representation
        query = self.pool_query.expand(batch_size, -1, -1)
        pooled_output, pool_attention = self.attention_pool(query, x, x)
        pooled_output = pooled_output.squeeze(1)  # (batch, d_model)

        if return_attention:
            attention_weights.append(pool_attention)

        # Multi-task outputs
        outputs = {}

        # Classification output
        classification_logits = self.classification_head(pooled_output)
        outputs["classification_logits"] = classification_logits
        outputs["classification_probs"] = F.softmax(classification_logits, dim=-1)

        # Degradation prediction
        degradation_pred = self.degradation_head(pooled_output)
        outputs["degradation_prediction"] = degradation_pred

        # Property predictions
        property_pred = self.property_head(pooled_output)
        outputs["property_predictions"] = property_pred

        # Uncertainty estimation
        uncertainty_pred = self.uncertainty_head(pooled_output)
        outputs["uncertainty_epistemic"] = torch.nn.Softplus()(uncertainty_pred[:, 0])
        outputs["uncertainty_aleatoric"] = F.softplus(uncertainty_pred[:, 1])

        if return_attention:
            outputs["attention_weights"] = attention_weights

        return outputs


class BayesianUncertaintyEstimator:
    """Bayesian uncertainty quantification using Monte Carlo dropout"""

    def __init__(self, model: nn.Module, num_samples: int = 100):
        self.model = model
        self.num_samples = num_samples

    def enable_dropout(self, model: nn.Module):
        """Enable dropout for uncertainty estimation"""
        for module in model.modules():
            if isinstance(module, nn.Dropout):
                module.train()

    def predict_with_uncertainty(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Predict with uncertainty quantification using Monte Carlo dropout

        Args:
            x: Input tensor

        Returns:
            Predictions with uncertainty estimates
        """
        self.model.eval()
        self.enable_dropout(self.model)

        predictions = []
        classification_probs = []
        degradation_preds = []
        uncertainty_estimates = []

        with torch.no_grad():
            for _ in range(self.num_samples):
                output = self.model(x)
                predictions.append(output["classification_probs"])
                classification_probs.append(output["classification_probs"])
                degradation_preds.append(output["degradation_prediction"])
                uncertainty_estimates.append(
                    torch.stack(
                        [
                            output["uncertainty_epistemic"],
                            output["uncertainty_aleatoric"],
                        ],
                        dim=1,
                    )
                )

        # Stack predictions
        classification_stack = torch.stack(
            classification_probs, dim=0
        )  # (num_samples, batch, classes)
        degradation_stack = torch.stack(degradation_preds, dim=0)
        uncertainty_stack = torch.stack(uncertainty_estimates, dim=0)

        # Calculate statistics
        mean_classification = classification_stack.mean(dim=0)
        std_classification = classification_stack.std(dim=0)

        mean_degradation = degradation_stack.mean(dim=0)
        std_degradation = degradation_stack.std(dim=0)

        mean_uncertainty = uncertainty_stack.mean(dim=0)

        # Calculate epistemic uncertainty (model uncertainty)
        epistemic_uncertainty = std_classification.mean(dim=1)

        # Calculate aleatoric uncertainty (data uncertainty)
        aleatoric_uncertainty = mean_uncertainty[:, 1]

        return {
            "mean_classification": mean_classification,
            "std_classification": std_classification,
            "mean_degradation": mean_degradation,
            "std_degradation": std_degradation,
            "epistemic_uncertainty": epistemic_uncertainty,
            "aleatoric_uncertainty": aleatoric_uncertainty,
            "total_uncertainty": epistemic_uncertainty + aleatoric_uncertainty,
        }


class EnsembleModel:
    """Ensemble model combining multiple approaches"""

    def __init__(self):
        self.models = {}
        self.weights = {}
        self.is_fitted = False

    def add_transformer_model(self, model: SpectralTransformer, weight: float = 1.0):
        """Add transformer model to ensemble"""
        self.models["transformer"] = model
        self.weights["transformer"] = weight

    def add_random_forest(self, n_estimators: int = 100, weight: float = 1.0):
        """Add Random Forest to ensemble"""
        self.models["random_forest_clf"] = RandomForestClassifier(
            n_estimators=n_estimators, random_state=42, oob_score=True
        )
        self.models["random_forest_reg"] = RandomForestRegressor(
            n_estimators=n_estimators, random_state=42, oob_score=True
        )
        self.weights["random_forest"] = weight

    def add_xgboost(self, weight: float = 1.0):
        """Add XGBoost to ensemble"""
        self.models["xgboost_clf"] = xgb.XGBClassifier(
            n_estimators=100, random_state=42, eval_metric="logloss"
        )
        self.models["xgboost_reg"] = xgb.XGBRegressor(n_estimators=100, random_state=42)
        self.weights["xgboost"] = weight

    def fit(
        self,
        X: np.ndarray,
        y_classification: np.ndarray,
        y_degradation: Optional[np.ndarray] = None,
    ):
        """
        Fit ensemble models

        Args:
            X: Input features (flattened spectra for traditional ML models)
            y_classification: Classification targets
            y_degradation: Degradation targets (optional)
        """
        # Fit Random Forest
        if "random_forest_clf" in self.models:
            self.models["random_forest_clf"].fit(X, y_classification)
            if y_degradation is not None:
                self.models["random_forest_reg"].fit(X, y_degradation)

        # Fit XGBoost
        if "xgboost_clf" in self.models:
            self.models["xgboost_clf"].fit(X, y_classification)
            if y_degradation is not None:
                self.models["xgboost_reg"].fit(X, y_degradation)

        self.is_fitted = True

    def predict(
        self, X: np.ndarray, X_transformer: Optional[torch.Tensor] = None
    ) -> ModelPrediction:
        """
        Ensemble prediction with uncertainty quantification

        Args:
            X: Input features for traditional ML models
            X_transformer: Input tensor for transformer model

        Returns:
            Ensemble prediction with uncertainty
        """
        if not self.is_fitted and "transformer" not in self.models:
            raise ValueError(
                "Ensemble must be fitted or contain pre-trained transformer"
            )

        predictions = {}
        classification_probs = []
        degradation_preds = []
        model_weights = []

        # Random Forest predictions
        if (
            "random_forest_clf" in self.models
            and self.models["random_forest_clf"] is not None
        ):
            rf_probs = self.models["random_forest_clf"].predict_proba(X)
            classification_probs.append(rf_probs)
            model_weights.append(self.weights["random_forest"])

            if "random_forest_reg" in self.models:
                rf_degradation = self.models["random_forest_reg"].predict(X)
                degradation_preds.append(rf_degradation)

        # XGBoost predictions
        if "xgboost_clf" in self.models and self.models["xgboost_clf"] is not None:
            xgb_probs = self.models["xgboost_clf"].predict_proba(X)
            classification_probs.append(xgb_probs)
            model_weights.append(self.weights["xgboost"])

            if "xgboost_reg" in self.models:
                xgb_degradation = self.models["xgboost_reg"].predict(X)
                degradation_preds.append(xgb_degradation)

        # Transformer predictions
        if "transformer" in self.models and X_transformer is not None:
            transformer_output = self.models["transformer"](X_transformer)
            transformer_probs = (
                transformer_output["classification_probs"].detach().numpy()
            )
            classification_probs.append(transformer_probs)
            model_weights.append(self.weights["transformer"])

            transformer_degradation = (
                transformer_output["degradation_prediction"].detach().numpy()
            )
            degradation_preds.append(transformer_degradation.flatten())

        # Weighted ensemble
        if classification_probs:
            model_weights = np.array(model_weights)
            model_weights = model_weights / np.sum(model_weights)  # Normalize

            # Weighted average of probabilities
            ensemble_probs = np.zeros_like(classification_probs[0])
            for i, probs in enumerate(classification_probs):
                ensemble_probs += model_weights[i] * probs

            # Predicted class
            predicted_class = np.argmax(ensemble_probs, axis=1)[0]
            confidence = np.max(ensemble_probs, axis=1)[0]

            # Calculate uncertainty from model disagreement
            prob_variance = np.var([probs[0] for probs in classification_probs], axis=0)
            epistemic_uncertainty = np.mean(prob_variance)

            # Aleatoric uncertainty (average across models)
            aleatoric_uncertainty = 1.0 - confidence  # Simple estimate

            # Degradation prediction
            ensemble_degradation = None
            if degradation_preds:
                ensemble_degradation = np.average(
                    degradation_preds, weights=model_weights, axis=0
                )[0]

        else:
            raise ValueError("No valid predictions could be made")

        # Feature importance (from Random Forest if available)
        feature_importance = None
        if (
            "random_forest_clf" in self.models
            and self.models["random_forest_clf"] is not None
        ):
            importance = self.models["random_forest_clf"].feature_importances_
            # Convert to wavenumber-based importance (assuming spectral input)
            feature_importance = {
                f"wavenumber_{i}": float(importance[i]) for i in range(len(importance))
            }

        return ModelPrediction(
            prediction=predicted_class,
            confidence=confidence,
            uncertainty_epistemic=epistemic_uncertainty,
            uncertainty_aleatoric=aleatoric_uncertainty,
            class_probabilities=ensemble_probs[0],
            feature_importance=feature_importance,
            explanation=self._generate_explanation(
                predicted_class, confidence, ensemble_degradation
            ),
        )

    def _generate_explanation(
        self,
        predicted_class: int,
        confidence: float,
        degradation: Optional[float] = None,
    ) -> str:
        """Generate human-readable explanation"""
        class_names = {0: "Stable (Unweathered)", 1: "Weathered"}
        class_name = class_names.get(predicted_class, f"Class {predicted_class}")

        explanation = f"Predicted class: {class_name} (confidence: {confidence:.3f})"

        if degradation is not None:
            explanation += f"\nEstimated degradation level: {degradation:.3f}"

        if confidence > 0.8:
            explanation += "\nHigh confidence prediction - strong spectral evidence"
        elif confidence > 0.6:
            explanation += "\nModerate confidence - some uncertainty in classification"
        else:
            explanation += "\nLow confidence - significant uncertainty, consider additional analysis"

        return explanation


class MultiTaskLearningFramework:
    """Framework for multi-task learning in polymer analysis"""

    def __init__(self, model: SpectralTransformer):
        self.model = model
        self.task_weights = {
            "classification": 1.0,
            "degradation": 0.5,
            "properties": 0.3,
        }
        self.optimizer = None
        self.scheduler = None

    def setup_training(self, learning_rate: float = 1e-4):
        """Setup optimizer and scheduler"""
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=learning_rate, weight_decay=0.01
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=100
        )

    def compute_loss(
        self,
        outputs: Dict[str, torch.Tensor],
        targets: MultiTaskTarget,
        batch_size: int,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute multi-task loss

        Args:
            outputs: Model outputs
            targets: Multi-task targets
            batch_size: Batch size

        Returns:
            Loss components
        """
        losses = {}
        total_loss = 0

        # Classification loss
        if targets.classification_target is not None:
            classification_loss = F.cross_entropy(
                outputs["classification_logits"],
                torch.tensor(
                    [targets.classification_target] * batch_size, dtype=torch.long
                ),
            )
            losses["classification"] = classification_loss
            total_loss += self.task_weights["classification"] * classification_loss

        # Degradation regression loss
        if targets.degradation_level is not None:
            degradation_loss = F.mse_loss(
                outputs["degradation_prediction"].squeeze(),
                torch.tensor(
                    [targets.degradation_level] * batch_size, dtype=torch.float
                ),
            )
            losses["degradation"] = degradation_loss
            total_loss += self.task_weights["degradation"] * degradation_loss

        # Property prediction loss
        if targets.property_predictions is not None:
            property_targets = torch.tensor(
                [[targets.property_predictions.get(f"prop_{i}", 0.0) for i in range(5)]]
                * batch_size,
                dtype=torch.float,
            )
            property_loss = F.mse_loss(
                outputs["property_predictions"], property_targets
            )
            losses["properties"] = property_loss
            total_loss += self.task_weights["properties"] * property_loss

        # Uncertainty regularization
        uncertainty_reg = torch.mean(outputs["uncertainty_epistemic"]) + torch.mean(
            outputs["uncertainty_aleatoric"]
        )
        losses["uncertainty_reg"] = uncertainty_reg
        total_loss += 0.01 * uncertainty_reg  # Small weight for regularization

        losses["total"] = total_loss
        return losses

    def train_step(self, x: torch.Tensor, targets: MultiTaskTarget) -> Dict[str, float]:
        """Single training step"""
        self.model.train()
        if self.optimizer is None:
            raise ValueError(
                "Optimizer is not initialized. Call setup_training() to initialize it."
            )
        self.optimizer.zero_grad()

        outputs = self.model(x)
        losses = self.compute_loss(outputs, targets, x.size(0))

        losses["total"].backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        return {
            k: float(v.item()) if torch.is_tensor(v) else float(v)
            for k, v in losses.items()
        }


class ModernMLPipeline:
    """Complete modern ML pipeline for polymer analysis"""

    def __init__(self, config: Optional[Dict] = None):
        self.config = config or self._default_config()
        self.transformer_model = None
        self.ensemble_model = None
        self.uncertainty_estimator = None
        self.multi_task_framework = None

    def _default_config(self) -> Dict:
        """Default configuration"""
        return {
            "transformer": {
                "d_model": 256,
                "num_heads": 8,
                "num_layers": 6,
                "d_ff": 1024,
                "dropout": 0.1,
                "num_classes": 2,
            },
            "ensemble": {
                "transformer_weight": 0.4,
                "random_forest_weight": 0.3,
                "xgboost_weight": 0.3,
            },
            "uncertainty": {"num_mc_samples": 50},
            "training": {"learning_rate": 1e-4, "batch_size": 32, "num_epochs": 100},
        }

    def initialize_models(self, input_dim: int = 1, max_seq_length: int = 2000):
        """Initialize all models"""
        # Transformer model
        self.transformer_model = SpectralTransformer(
            input_dim=input_dim,
            d_model=self.config["transformer"]["d_model"],
            num_heads=self.config["transformer"]["num_heads"],
            num_layers=self.config["transformer"]["num_layers"],
            d_ff=self.config["transformer"]["d_ff"],
            max_seq_length=max_seq_length,
            num_classes=self.config["transformer"]["num_classes"],
            dropout=self.config["transformer"]["dropout"],
        )

        # Uncertainty estimator
        self.uncertainty_estimator = BayesianUncertaintyEstimator(
            self.transformer_model,
            num_samples=self.config["uncertainty"]["num_mc_samples"],
        )

        # Multi-task framework
        self.multi_task_framework = MultiTaskLearningFramework(self.transformer_model)

        # Ensemble model
        self.ensemble_model = EnsembleModel()
        self.ensemble_model.add_transformer_model(
            self.transformer_model, self.config["ensemble"]["transformer_weight"]
        )
        self.ensemble_model.add_random_forest(
            weight=self.config["ensemble"]["random_forest_weight"]
        )
        self.ensemble_model.add_xgboost(
            weight=self.config["ensemble"]["xgboost_weight"]
        )

    def train_ensemble(
        self,
        X_flat: np.ndarray,
        X_transformer: torch.Tensor,
        y_classification: np.ndarray,
        y_degradation: Optional[np.ndarray] = None,
    ):
        """Train the ensemble model"""
        if self.ensemble_model is None:
            raise ValueError("Models not initialized. Call initialize_models() first.")

        # Train traditional ML models
        self.ensemble_model.fit(X_flat, y_classification, y_degradation)

        # Setup transformer training
        if self.multi_task_framework is None:
            raise ValueError(
                "Multi-task framework is not initialized. Call initialize_models() first."
            )
        self.multi_task_framework.setup_training(
            self.config["training"]["learning_rate"]
        )

        print(
            "Ensemble training completed (transformer training would require full training loop)"
        )

    def predict_with_all_methods(
        self, X_flat: np.ndarray, X_transformer: torch.Tensor
    ) -> Dict[str, Any]:
        """
        Comprehensive prediction using all methods

        Args:
            X_flat: Flattened spectral data for traditional ML
            X_transformer: Tensor format for transformer

        Returns:
            Complete prediction results
        """
        results = {}

        # Ensemble prediction
        if self.ensemble_model is None:
            raise ValueError(
                "Ensemble model is not initialized. Call initialize_models() first."
            )
        ensemble_pred = self.ensemble_model.predict(X_flat, X_transformer)
        results["ensemble"] = ensemble_pred

        # Transformer with uncertainty
        if self.transformer_model is not None:
            if self.uncertainty_estimator is None:
                raise ValueError(
                    "Uncertainty estimator is not initialized. Call initialize_models() first."
                )
            uncertainty_pred = self.uncertainty_estimator.predict_with_uncertainty(
                X_transformer
            )
            results["transformer_uncertainty"] = uncertainty_pred

        # Individual model predictions for comparison
        individual_predictions = {}

        if (
            self.ensemble_model is not None
            and "random_forest_clf" in self.ensemble_model.models
        ):
            rf_pred = self.ensemble_model.models["random_forest_clf"].predict_proba(
                X_flat
            )[0]
            individual_predictions["random_forest"] = rf_pred

        if "xgboost_clf" in self.ensemble_model.models:
            xgb_pred = self.ensemble_model.models["xgboost_clf"].predict_proba(X_flat)[
                0
            ]
            individual_predictions["xgboost"] = xgb_pred

        results["individual_models"] = individual_predictions

        return results

    def get_model_insights(
        self, X_flat: np.ndarray, X_transformer: torch.Tensor
    ) -> Dict[str, Any]:
        """
        Generate insights about model behavior and predictions

        Args:
            X_flat: Flattened spectral data
            X_transformer: Transformer input format

        Returns:
            Model insights and explanations
        """
        insights = {}

        # Feature importance from Random Forest
        if "random_forest_clf" in self.ensemble_model.models:
            if (
                self.ensemble_model
                and "random_forest_clf" in self.ensemble_model.models
                and self.ensemble_model.models["random_forest_clf"] is not None
            ):
                rf_importance = self.ensemble_model.models[
                    "random_forest_clf"
                ].feature_importances_
            else:
                rf_importance = None
            if rf_importance is not None:
                top_features = np.argsort(rf_importance)[-10:][::-1]
            else:
                top_features = []
            insights["top_spectral_regions"] = {
                f"wavenumber_{idx}": float(rf_importance[idx])
                for idx in top_features
                if rf_importance is not None
            }

        # Attention weights from transformer
        if self.transformer_model is not None:
            self.transformer_model.eval()
            with torch.no_grad():
                outputs = self.transformer_model(X_transformer, return_attention=True)
                if "attention_weights" in outputs:
                    insights["attention_patterns"] = outputs["attention_weights"]

        # Uncertainty analysis
        predictions = self.predict_with_all_methods(X_flat, X_transformer)
        if "transformer_uncertainty" in predictions:
            uncertainty_data = predictions["transformer_uncertainty"]
            insights["uncertainty_analysis"] = {
                "epistemic_uncertainty": float(
                    uncertainty_data["epistemic_uncertainty"].mean()
                ),
                "aleatoric_uncertainty": float(
                    uncertainty_data["aleatoric_uncertainty"].mean()
                ),
                "total_uncertainty": float(
                    uncertainty_data["total_uncertainty"].mean()
                ),
                "confidence_level": (
                    "high"
                    if uncertainty_data["total_uncertainty"].mean() < 0.1
                    else (
                        "medium"
                        if uncertainty_data["total_uncertainty"].mean() < 0.3
                        else "low"
                    )
                ),
            }

        # Model agreement analysis
        if "individual_models" in predictions:
            individual = predictions["individual_models"]
            agreements = []
            for model1_name, model1_pred in individual.items():
                for model2_name, model2_pred in individual.items():
                    if model1_name != model2_name:
                        # Calculate agreement based on prediction similarity
                        agreement = 1.0 - np.abs(model1_pred - model2_pred).mean()
                        agreements.append(agreement)

            insights["model_agreement"] = {
                "average_agreement": float(np.mean(agreements)) if agreements else 0.0,
                "agreement_level": (
                    "high"
                    if np.mean(agreements) > 0.8
                    else "medium" if np.mean(agreements) > 0.6 else "low"
                ),
            }

        return insights

    def save_models(self, save_path: Path):
        """Save trained models"""
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        # Save transformer model
        if self.transformer_model is not None:
            torch.save(
                self.transformer_model.state_dict(), save_path / "transformer_model.pth"
            )

        # Save configuration
        with open(save_path / "config.json", "w") as f:
            json.dump(self.config, f, indent=2)

        print(f"Models saved to {save_path}")

    def load_models(self, load_path: Path):
        """Load pre-trained models"""
        load_path = Path(load_path)

        # Load configuration
        with open(load_path / "config.json", "r") as f:
            self.config = json.load(f)

        # Initialize and load transformer
        self.initialize_models()
        if (
            self.transformer_model is not None
            and (load_path / "transformer_model.pth").exists()
        ):
            self.transformer_model.load_state_dict(
                torch.load(load_path / "transformer_model.pth", map_location="cpu")
            )
        else:
            raise ValueError(
                "Transformer model is not initialized or model file is missing."
            )

        print(f"Models loaded from {load_path}")


# Utility functions for data preparation
def prepare_transformer_input(
    spectral_data: np.ndarray, max_length: int = 2000
) -> torch.Tensor:
    """
    Prepare spectral data for transformer input

    Args:
        spectral_data: Raw spectral intensities (1D array)
        max_length: Maximum sequence length

    Returns:
        Formatted tensor for transformer
    """
    # Ensure proper length
    if len(spectral_data) > max_length:
        # Downsample
        indices = np.linspace(0, len(spectral_data) - 1, max_length, dtype=int)
        spectral_data = spectral_data[indices]
    elif len(spectral_data) < max_length:
        # Pad with zeros
        padding = np.zeros(max_length - len(spectral_data))
        spectral_data = np.concatenate([spectral_data, padding])

    # Reshape for transformer: (batch_size, sequence_length, features)
    return torch.tensor(spectral_data, dtype=torch.float32).unsqueeze(0).unsqueeze(-1)


def create_multitask_targets(
    classification_label: int,
    degradation_score: Optional[float] = None,
    material_properties: Optional[Dict[str, float]] = None,
) -> MultiTaskTarget:
    """
    Create multi-task learning targets

    Args:
        classification_label: Classification target (0 or 1)
        degradation_score: Continuous degradation score [0, 1]
        material_properties: Dictionary of material properties

    Returns:
        MultiTaskTarget object
    """
    return MultiTaskTarget(
        classification_target=classification_label,
        degradation_level=degradation_score,
        property_predictions=material_properties,
    )