Upload utils/diagnostics.py with huggingface_hub
Browse files- utils/diagnostics.py +109 -0
utils/diagnostics.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model diagnostics and validation utilities."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from physics_informed_bo.models.base import SurrogateModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def model_diagnostics(
|
| 13 |
+
surrogate: SurrogateModel,
|
| 14 |
+
X_test: Tensor,
|
| 15 |
+
y_test: Tensor,
|
| 16 |
+
) -> Dict:
|
| 17 |
+
"""Compute diagnostic metrics for the surrogate model.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
surrogate: Fitted surrogate model.
|
| 21 |
+
X_test: Test inputs (n, d).
|
| 22 |
+
y_test: Test targets (n, 1).
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
Dict with RMSE, MAE, R2, NLPD, calibration metrics.
|
| 26 |
+
"""
|
| 27 |
+
mean, var = surrogate.predict(X_test)
|
| 28 |
+
mean = mean.squeeze()
|
| 29 |
+
var = var.squeeze()
|
| 30 |
+
y_test = y_test.squeeze()
|
| 31 |
+
|
| 32 |
+
residuals = y_test - mean
|
| 33 |
+
|
| 34 |
+
# Standard metrics
|
| 35 |
+
rmse = float((residuals**2).mean().sqrt())
|
| 36 |
+
mae = float(residuals.abs().mean())
|
| 37 |
+
ss_res = float((residuals**2).sum())
|
| 38 |
+
ss_tot = float(((y_test - y_test.mean()) ** 2).sum())
|
| 39 |
+
r2 = 1 - ss_res / (ss_tot + 1e-12)
|
| 40 |
+
|
| 41 |
+
# Negative Log Predictive Density
|
| 42 |
+
nlpd = float(
|
| 43 |
+
0.5 * (torch.log(2 * torch.pi * var) + residuals**2 / var).mean()
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
# Calibration: fraction of true values within predicted CI
|
| 47 |
+
std = var.sqrt()
|
| 48 |
+
in_1sigma = float(((mean - std <= y_test) & (y_test <= mean + std)).float().mean())
|
| 49 |
+
in_2sigma = float(((mean - 2 * std <= y_test) & (y_test <= mean + 2 * std)).float().mean())
|
| 50 |
+
|
| 51 |
+
return {
|
| 52 |
+
"rmse": rmse,
|
| 53 |
+
"mae": mae,
|
| 54 |
+
"r2": r2,
|
| 55 |
+
"nlpd": nlpd,
|
| 56 |
+
"calibration_1sigma": in_1sigma, # Ideal: ~0.68
|
| 57 |
+
"calibration_2sigma": in_2sigma, # Ideal: ~0.95
|
| 58 |
+
"mean_predicted_std": float(std.mean()),
|
| 59 |
+
"n_test": len(X_test),
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def leave_one_out_cv(
|
| 64 |
+
surrogate_class,
|
| 65 |
+
surrogate_kwargs: Dict,
|
| 66 |
+
X: Tensor,
|
| 67 |
+
y: Tensor,
|
| 68 |
+
) -> Dict:
|
| 69 |
+
"""Perform leave-one-out cross-validation for the surrogate model.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
surrogate_class: Class of the surrogate model to evaluate.
|
| 73 |
+
surrogate_kwargs: Keyword arguments for the surrogate constructor.
|
| 74 |
+
X: Full dataset inputs (n, d).
|
| 75 |
+
y: Full dataset targets (n, 1).
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
Dict with LOO-CV metrics.
|
| 79 |
+
"""
|
| 80 |
+
n = len(X)
|
| 81 |
+
predictions = torch.zeros(n)
|
| 82 |
+
variances = torch.zeros(n)
|
| 83 |
+
|
| 84 |
+
for i in range(n):
|
| 85 |
+
# Leave out point i
|
| 86 |
+
mask = torch.ones(n, dtype=torch.bool)
|
| 87 |
+
mask[i] = False
|
| 88 |
+
|
| 89 |
+
X_train = X[mask]
|
| 90 |
+
y_train = y[mask]
|
| 91 |
+
|
| 92 |
+
model = surrogate_class(**surrogate_kwargs)
|
| 93 |
+
model.fit(X_train, y_train)
|
| 94 |
+
|
| 95 |
+
mean_i, var_i = model.predict(X[i:i+1])
|
| 96 |
+
predictions[i] = mean_i.squeeze()
|
| 97 |
+
variances[i] = var_i.squeeze()
|
| 98 |
+
|
| 99 |
+
y_flat = y.squeeze()
|
| 100 |
+
residuals = y_flat - predictions
|
| 101 |
+
|
| 102 |
+
return {
|
| 103 |
+
"loo_rmse": float((residuals**2).mean().sqrt()),
|
| 104 |
+
"loo_mae": float(residuals.abs().mean()),
|
| 105 |
+
"loo_r2": float(1 - (residuals**2).sum() / ((y_flat - y_flat.mean()) ** 2).sum()),
|
| 106 |
+
"loo_nlpd": float(
|
| 107 |
+
0.5 * (torch.log(2 * torch.pi * variances) + residuals**2 / variances).mean()
|
| 108 |
+
),
|
| 109 |
+
}
|