ravimohan19 commited on
Commit
d2b8aad
·
verified ·
1 Parent(s): 1f689b1

Upload utils/visualization.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. utils/visualization.py +208 -0
utils/visualization.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Visualization utilities for physics-informed Bayesian optimization."""
2
+
3
+ from typing import Callable, Dict, List, Optional, Tuple
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ import numpy as np
8
+
9
+
10
+ def plot_convergence(
11
+ campaign_df,
12
+ maximize: bool = True,
13
+ title: str = "Optimization Convergence",
14
+ figsize: Tuple[int, int] = (10, 6),
15
+ ):
16
+ """Plot the optimization convergence curve.
17
+
18
+ Args:
19
+ campaign_df: DataFrame from OptimizationCampaign.to_dataframe().
20
+ maximize: Whether the objective is being maximized.
21
+ title: Plot title.
22
+ figsize: Figure size.
23
+
24
+ Returns:
25
+ matplotlib Figure.
26
+ """
27
+ import matplotlib.pyplot as plt
28
+
29
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
30
+
31
+ objectives = campaign_df["objective"].values
32
+
33
+ # Left: all observations
34
+ ax1.plot(range(len(objectives)), objectives, "o-", alpha=0.6, markersize=4)
35
+ ax1.set_xlabel("Experiment Number")
36
+ ax1.set_ylabel("Objective")
37
+ ax1.set_title("All Observations")
38
+ ax1.grid(True, alpha=0.3)
39
+
40
+ # Right: best-so-far
41
+ if maximize:
42
+ best_so_far = np.maximum.accumulate(objectives)
43
+ else:
44
+ best_so_far = np.minimum.accumulate(objectives)
45
+
46
+ ax2.plot(range(len(best_so_far)), best_so_far, "s-", color="green", markersize=4)
47
+ ax2.set_xlabel("Experiment Number")
48
+ ax2.set_ylabel("Best Objective")
49
+ ax2.set_title("Best So Far")
50
+ ax2.grid(True, alpha=0.3)
51
+
52
+ fig.suptitle(title, fontsize=14)
53
+ plt.tight_layout()
54
+ return fig
55
+
56
+
57
+ def plot_surrogate_1d(
58
+ surrogate,
59
+ bounds: Tuple[float, float],
60
+ X_observed: Optional[Tensor] = None,
61
+ y_observed: Optional[Tensor] = None,
62
+ physics_fn: Optional[Callable] = None,
63
+ true_fn: Optional[Callable] = None,
64
+ n_grid: int = 200,
65
+ title: str = "Surrogate Model",
66
+ figsize: Tuple[int, int] = (10, 6),
67
+ ):
68
+ """Plot a 1D surrogate model with confidence intervals.
69
+
70
+ Args:
71
+ surrogate: A SurrogateModel instance.
72
+ bounds: (lower, upper) for the 1D input.
73
+ X_observed: Observed inputs (n, 1).
74
+ y_observed: Observed outputs (n, 1).
75
+ physics_fn: Optional physics model for comparison.
76
+ true_fn: Optional true function for comparison.
77
+ n_grid: Number of grid points.
78
+ title: Plot title.
79
+ figsize: Figure size.
80
+
81
+ Returns:
82
+ matplotlib Figure.
83
+ """
84
+ import matplotlib.pyplot as plt
85
+
86
+ fig, ax = plt.subplots(figsize=figsize)
87
+
88
+ X_grid = torch.linspace(bounds[0], bounds[1], n_grid).unsqueeze(-1).to(torch.float64)
89
+ mean, var = surrogate.predict(X_grid)
90
+ std = var.sqrt()
91
+
92
+ x_np = X_grid.squeeze().numpy()
93
+ mean_np = mean.squeeze().detach().numpy()
94
+ std_np = std.squeeze().detach().numpy()
95
+
96
+ # Surrogate prediction
97
+ ax.plot(x_np, mean_np, "b-", label="Surrogate Mean", linewidth=2)
98
+ ax.fill_between(
99
+ x_np,
100
+ mean_np - 2 * std_np,
101
+ mean_np + 2 * std_np,
102
+ alpha=0.2,
103
+ color="blue",
104
+ label="95% CI",
105
+ )
106
+
107
+ # Physics model
108
+ if physics_fn is not None:
109
+ with torch.no_grad():
110
+ physics_pred = physics_fn(X_grid).squeeze().numpy()
111
+ ax.plot(x_np, physics_pred, "r--", label="Physics Model", linewidth=1.5)
112
+
113
+ # True function
114
+ if true_fn is not None:
115
+ with torch.no_grad():
116
+ true_pred = true_fn(X_grid).squeeze().numpy()
117
+ ax.plot(x_np, true_pred, "k-", label="True Function", linewidth=1.5, alpha=0.7)
118
+
119
+ # Observations
120
+ if X_observed is not None and y_observed is not None:
121
+ ax.scatter(
122
+ X_observed.squeeze().numpy(),
123
+ y_observed.squeeze().numpy(),
124
+ c="red",
125
+ s=50,
126
+ zorder=5,
127
+ label="Observations",
128
+ edgecolors="black",
129
+ )
130
+
131
+ ax.set_xlabel("Input")
132
+ ax.set_ylabel("Output")
133
+ ax.set_title(title)
134
+ ax.legend()
135
+ ax.grid(True, alpha=0.3)
136
+ plt.tight_layout()
137
+ return fig
138
+
139
+
140
+ def plot_surrogate_2d(
141
+ surrogate,
142
+ bounds: Tensor,
143
+ param_names: Tuple[str, str] = ("x1", "x2"),
144
+ X_observed: Optional[Tensor] = None,
145
+ n_grid: int = 50,
146
+ title: str = "Surrogate Model (2D)",
147
+ figsize: Tuple[int, int] = (12, 5),
148
+ ):
149
+ """Plot 2D surrogate model as contour plots (mean and uncertainty).
150
+
151
+ Args:
152
+ surrogate: A SurrogateModel instance.
153
+ bounds: Tensor of shape (2, 2) with [lower, upper] bounds.
154
+ param_names: Names of the two parameters.
155
+ X_observed: Observed inputs (n, 2).
156
+ n_grid: Grid resolution per dimension.
157
+ title: Plot title.
158
+ figsize: Figure size.
159
+
160
+ Returns:
161
+ matplotlib Figure.
162
+ """
163
+ import matplotlib.pyplot as plt
164
+
165
+ x1 = torch.linspace(float(bounds[0, 0]), float(bounds[1, 0]), n_grid)
166
+ x2 = torch.linspace(float(bounds[0, 1]), float(bounds[1, 1]), n_grid)
167
+ X1, X2 = torch.meshgrid(x1, x2, indexing="ij")
168
+ X_grid = torch.stack([X1.flatten(), X2.flatten()], dim=-1).to(torch.float64)
169
+
170
+ mean, var = surrogate.predict(X_grid)
171
+ mean_2d = mean.squeeze().reshape(n_grid, n_grid).detach().numpy()
172
+ std_2d = var.sqrt().squeeze().reshape(n_grid, n_grid).detach().numpy()
173
+
174
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
175
+
176
+ # Mean
177
+ c1 = ax1.contourf(
178
+ X1.numpy(), X2.numpy(), mean_2d, levels=20, cmap="viridis"
179
+ )
180
+ plt.colorbar(c1, ax=ax1)
181
+ ax1.set_xlabel(param_names[0])
182
+ ax1.set_ylabel(param_names[1])
183
+ ax1.set_title("Predicted Mean")
184
+
185
+ # Uncertainty
186
+ c2 = ax2.contourf(
187
+ X1.numpy(), X2.numpy(), std_2d, levels=20, cmap="plasma"
188
+ )
189
+ plt.colorbar(c2, ax=ax2)
190
+ ax2.set_xlabel(param_names[0])
191
+ ax2.set_ylabel(param_names[1])
192
+ ax2.set_title("Predicted Std Dev")
193
+
194
+ # Overlay observations
195
+ if X_observed is not None:
196
+ for ax in [ax1, ax2]:
197
+ ax.scatter(
198
+ X_observed[:, 0].numpy(),
199
+ X_observed[:, 1].numpy(),
200
+ c="red",
201
+ s=30,
202
+ edgecolors="white",
203
+ zorder=5,
204
+ )
205
+
206
+ fig.suptitle(title, fontsize=14)
207
+ plt.tight_layout()
208
+ return fig