ravimohan19 commited on
Commit
b82edbd
·
verified ·
1 Parent(s): 1bcef47

Upload models/physics_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/physics_model.py +152 -0
models/physics_model.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Physics model wrappers for use as GP mean functions and standalone surrogates."""
2
+
3
+ from typing import Callable, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ import gpytorch
8
+ from gpytorch.means import Mean
9
+
10
+ from physics_informed_bo.models.base import SurrogateModel
11
+
12
+
13
class PhysicsMeanFunction(Mean):
    """GPyTorch mean function backed by a user-supplied physics model.

    Installs a physics prediction as the GP prior mean so that the GP
    kernel only has to capture the residual between the physics model
    and the observed data.

    Example:
        def arrhenius(X):
            # X[:, 0] = temperature, X[:, 1] = activation energy
            T, Ea = X[:, 0], X[:, 1]
            R = 8.314
            return torch.log(1e13 * torch.exp(-Ea / (R * T)))

        mean_fn = PhysicsMeanFunction(arrhenius)
    """

    def __init__(
        self,
        physics_fn: Callable[[Tensor], Tensor],
        output_scale: float = 1.0,
        learnable_scale: bool = True,
    ):
        super().__init__()
        self.physics_fn = physics_fn
        # Single construction point for the scale tensor; it is either
        # registered as a trainable parameter or as a fixed buffer.
        scale = torch.tensor(output_scale)
        if learnable_scale:
            self.register_parameter("raw_output_scale", torch.nn.Parameter(scale))
        else:
            self.register_buffer("raw_output_scale", scale)

    @property
    def output_scale(self) -> Tensor:
        # No constraint transform is applied; the raw value is used as-is.
        return self.raw_output_scale

    def forward(self, X: Tensor) -> Tensor:
        """Evaluate the physics model at X and apply the output scale."""
        return self.physics_fn(X) * self.output_scale
56
+
57
+
58
class PhysicsModel(SurrogateModel):
    """Standalone physics model as a surrogate (no GP, deterministic).

    Useful as a baseline or when no experimental data is available yet.
    Uncertainty is estimated via finite-difference sensitivity to the
    declared parameter uncertainties plus a fixed observation-noise floor.
    """

    def __init__(
        self,
        physics_fn: Callable[[Tensor], Tensor],
        noise_std: float = 0.1,
        param_uncertainty: Optional[Dict[str, float]] = None,
    ):
        """
        Args:
            physics_fn: Callable that takes (n, d) tensor and returns (n,) tensor.
            noise_std: Assumed observation noise standard deviation.
            param_uncertainty: Dict mapping parameter names ("x0", "x1", ...)
                to their uncertainty (std) for propagating uncertainty
                through the physics model via local sensitivities.
        """
        self.physics_fn = physics_fn
        self.noise_std = noise_std
        self.param_uncertainty = param_uncertainty or {}
        self._train_X: Optional[Tensor] = None
        self._train_y: Optional[Tensor] = None

    def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        """Predict using the physics model with estimated uncertainty.

        Args:
            X: (n, d) input tensor.

        Returns:
            Tuple of (n, 1) mean and (n, 1) variance tensors.
        """
        # Neither the mean nor the finite-difference variance needs a
        # computation graph, so evaluate everything under no_grad.
        with torch.no_grad():
            mean = self.physics_fn(X).unsqueeze(-1)
            variance = self._estimate_variance(X)
        return mean, variance

    def _estimate_variance(self, X: Tensor) -> Tuple[Tensor]:
        """Estimate predictive variance via finite-difference sensitivity analysis.

        Each input dimension i with a declared uncertainty under key "x{i}"
        contributes (df/dx_i * sigma_i)^2 on top of the noise floor.
        """
        eps = 1e-4
        n, d = X.shape
        var = torch.full((n, 1), self.noise_std**2, dtype=X.dtype, device=X.device)

        if self.param_uncertainty:
            for i in range(d):
                sigma = self.param_uncertainty.get(f"x{i}")
                if sigma is None:
                    # Skip the two physics evaluations entirely for
                    # dimensions without a declared uncertainty.
                    continue
                X_plus = X.clone()
                X_plus[:, i] += eps
                X_minus = X.clone()
                X_minus[:, i] -= eps
                # Central difference: O(eps^2) accurate gradient estimate.
                df_dx = (self.physics_fn(X_plus) - self.physics_fn(X_minus)) / (2 * eps)
                var[:, 0] += (df_dx * sigma) ** 2

        return var

    def fit(self, X: Tensor, y: Tensor) -> None:
        """Store observations (physics model is not fitted, but we track data).

        Also re-estimates the noise level from the residuals between the
        physics predictions and the observations, when enough data exists.
        """
        self._train_X = X
        self._train_y = y

        # Update the noise estimate from residuals. std() of a single
        # element is NaN, so require at least two observations.
        if X is not None and y is not None and y.numel() > 1:
            with torch.no_grad():
                pred = self.physics_fn(X)
                residuals = y.squeeze() - pred
                self.noise_std = float(residuals.std())

    def posterior(self, X: Tensor):
        """Return a simple posterior-like object for BoTorch compatibility."""
        mean, var = self.predict(X)
        return _SimplePosterior(mean, var)
131
+
132
+
133
+ class _SimplePosterior:
134
+ """Minimal posterior wrapper for BoTorch compatibility."""
135
+
136
+ def __init__(self, mean: Tensor, variance: Tensor):
137
+ self._mean = mean
138
+ self._variance = variance
139
+
140
+ @property
141
+ def mean(self) -> Tensor:
142
+ return self._mean
143
+
144
+ @property
145
+ def variance(self) -> Tensor:
146
+ return self._variance
147
+
148
+ def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
149
+ """Draw reparameterized samples from the posterior."""
150
+ std = self._variance.sqrt()
151
+ eps = torch.randn(*sample_shape, *self._mean.shape, device=self._mean.device)
152
+ return self._mean + std * eps