ravimohan19 commited on
Commit
1bcef47
·
verified ·
1 Parent(s): f7040fb

Upload models/multi_fidelity.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/multi_fidelity.py +142 -0
models/multi_fidelity.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Multi-fidelity model: physics model as low-fidelity, experiments as high-fidelity."""
2
+
3
+ from typing import Callable, Optional, Tuple
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ from botorch.models import SingleTaskMultiFidelityGP
8
+ from botorch.models.transforms.outcome import Standardize
9
+
10
+ from physics_informed_bo.models.base import SurrogateModel
11
+
12
+
13
+ class MultiFidelitySurrogate(SurrogateModel):
14
+ """Multi-fidelity BO model using physics as cheap low-fidelity source.
15
+
16
+ Uses BoTorch's SingleTaskMultiFidelityGP to jointly model:
17
+ - Low-fidelity: physics model predictions (cheap, approximate)
18
+ - High-fidelity: experimental observations (expensive, accurate)
19
+
20
+ The model learns the correlation between fidelities to transfer
21
+ knowledge from physics to improve experimental predictions.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ physics_fn: Callable[[Tensor], Tensor],
27
+ fidelity_dim: int = -1,
28
+ device: str = "cpu",
29
+ dtype: torch.dtype = torch.float64,
30
+ ):
31
+ """
32
+ Args:
33
+ physics_fn: Physics model function (low-fidelity source).
34
+ fidelity_dim: Column index for the fidelity indicator.
35
+ device: Torch device.
36
+ dtype: Torch dtype.
37
+ """
38
+ self.physics_fn = physics_fn
39
+ self.fidelity_dim = fidelity_dim
40
+ self.device = torch.device(device)
41
+ self.dtype = dtype
42
+ self._model = None
43
+
44
+ def build_multi_fidelity_data(
45
+ self,
46
+ X_experiment: Tensor,
47
+ y_experiment: Tensor,
48
+ X_physics_grid: Optional[Tensor] = None,
49
+ n_physics_points: int = 100,
50
+ ) -> Tuple[Tensor, Tensor]:
51
+ """Combine physics predictions and experimental data into multi-fidelity dataset.
52
+
53
+ Args:
54
+ X_experiment: Experimental inputs (n_exp, d).
55
+ y_experiment: Experimental outputs (n_exp, 1).
56
+ X_physics_grid: Optional grid for physics evaluations.
57
+ n_physics_points: Number of physics evaluation points if no grid given.
58
+
59
+ Returns:
60
+ X_mf: Combined inputs with fidelity column (n_total, d+1).
61
+ y_mf: Combined outputs (n_total, 1).
62
+ """
63
+ d = X_experiment.shape[-1]
64
+
65
+ # Generate physics data
66
+ if X_physics_grid is None:
67
+ X_physics_grid = torch.rand(
68
+ n_physics_points, d, device=self.device, dtype=self.dtype
69
+ )
70
+
71
+ with torch.no_grad():
72
+ y_physics = self.physics_fn(X_physics_grid).unsqueeze(-1)
73
+
74
+ # Add fidelity column: 0 = low (physics), 1 = high (experiment)
75
+ fidelity_low = torch.zeros(len(X_physics_grid), 1, device=self.device, dtype=self.dtype)
76
+ fidelity_high = torch.ones(len(X_experiment), 1, device=self.device, dtype=self.dtype)
77
+
78
+ X_physics_mf = torch.cat([X_physics_grid, fidelity_low], dim=-1)
79
+ X_experiment_mf = torch.cat([X_experiment, fidelity_high], dim=-1)
80
+
81
+ X_mf = torch.cat([X_physics_mf, X_experiment_mf], dim=0)
82
+ y_mf = torch.cat([y_physics, y_experiment], dim=0)
83
+
84
+ return X_mf, y_mf
85
+
86
+ def fit(
87
+ self,
88
+ X: Tensor,
89
+ y: Tensor,
90
+ training_iterations: int = 200,
91
+ lr: float = 0.05,
92
+ ) -> None:
93
+ """Fit the multi-fidelity GP model.
94
+
95
+ X should include a fidelity column (use build_multi_fidelity_data).
96
+ """
97
+ X = X.to(device=self.device, dtype=self.dtype)
98
+ y = y.to(device=self.device, dtype=self.dtype)
99
+ if y.dim() == 1:
100
+ y = y.unsqueeze(-1)
101
+
102
+ d = X.shape[-1]
103
+ fidelity_col = d - 1 if self.fidelity_dim == -1 else self.fidelity_dim
104
+
105
+ self._model = SingleTaskMultiFidelityGP(
106
+ train_X=X,
107
+ train_Y=y,
108
+ data_fidelities=[fidelity_col],
109
+ outcome_transform=Standardize(m=1),
110
+ ).to(device=self.device, dtype=self.dtype)
111
+
112
+ # Optimize hyperparameters
113
+ from botorch.fit import fit_gpytorch_mll
114
+ from gpytorch.mlls import ExactMarginalLogLikelihood
115
+
116
+ mll = ExactMarginalLogLikelihood(self._model.likelihood, self._model)
117
+ fit_gpytorch_mll(mll)
118
+
119
+ def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
120
+ """Predict at high fidelity (fidelity=1)."""
121
+ X = X.to(device=self.device, dtype=self.dtype)
122
+
123
+ # Add high-fidelity indicator
124
+ if X.shape[-1] == self._model.train_inputs[0].shape[-1] - 1:
125
+ fidelity_col = torch.ones(len(X), 1, device=self.device, dtype=self.dtype)
126
+ X = torch.cat([X, fidelity_col], dim=-1)
127
+
128
+ posterior = self._model.posterior(X)
129
+ return posterior.mean, posterior.variance
130
+
131
+ def posterior(self, X: Tensor):
132
+ X = X.to(device=self.device, dtype=self.dtype)
133
+ if X.shape[-1] == self._model.train_inputs[0].shape[-1] - 1:
134
+ fidelity_col = torch.ones(
135
+ *X.shape[:-1], 1, device=self.device, dtype=self.dtype
136
+ )
137
+ X = torch.cat([X, fidelity_col], dim=-1)
138
+ return self._model.posterior(X)
139
+
140
+ @property
141
+ def model(self):
142
+ return self._model