ravimohan19 commited on
Commit
28656f7
·
verified ·
1 Parent(s): 9660391

Upload optimizers/ax_optimizer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. optimizers/ax_optimizer.py +126 -0
optimizers/ax_optimizer.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AX Platform optimizer backend for physics-informed BO."""
2
+
3
+ from typing import Callable, Dict, List, Optional, Tuple
4
+
5
+ import torch
6
+ from torch import Tensor
7
+
8
+ from physics_informed_bo.config import OptimizationConfig
9
+ from physics_informed_bo.optimizers.base_optimizer import BaseOptimizer
10
+
11
+
12
+ class AXOptimizer(BaseOptimizer):
13
+ """AX Platform backend for structured experiment design.
14
+
15
+ AX provides a higher-level API for experiment management, including:
16
+ - Structured parameter spaces with constraints
17
+ - Multi-objective optimization
18
+ - Human-in-the-loop experiments
19
+ - Automatic model selection
20
+
21
+ The physics model is injected as a custom BoTorch model generator
22
+ within AX's modular framework.
23
+ """
24
+
25
+ def __init__(self, config: OptimizationConfig):
26
+ super().__init__(config)
27
+ self._experiment = None
28
+ self._gs = None
29
+ self._parameter_names: List[str] = []
30
+
31
+ def setup_experiment(
32
+ self,
33
+ parameter_names: List[str],
34
+ bounds: Dict[str, Tuple[float, float]],
35
+ objective_name: str = "objective",
36
+ minimize: bool = False,
37
+ outcome_constraints: Optional[List[str]] = None,
38
+ ) -> None:
39
+ """Set up an AX experiment with physics-informed model.
40
+
41
+ Args:
42
+ parameter_names: Names of input parameters.
43
+ bounds: Dict of {param_name: (lower, upper)}.
44
+ objective_name: Name of the objective metric.
45
+ minimize: Whether to minimize (True) or maximize (False).
46
+ outcome_constraints: List of outcome constraint strings e.g. ["metric >= 0.5"].
47
+ """
48
+ try:
49
+ from ax.service.ax_client import AxClient
50
+ from ax.service.utils.instantiation import ObjectiveProperties
51
+ except ImportError:
52
+ raise ImportError(
53
+ "AX Platform is required for AXOptimizer. "
54
+ "Install with: pip install ax-platform"
55
+ )
56
+
57
+ self._parameter_names = parameter_names
58
+
59
+ # Build parameter list for AX
60
+ parameters = []
61
+ for name in parameter_names:
62
+ lb, ub = bounds[name]
63
+ parameters.append(
64
+ {
65
+ "name": name,
66
+ "type": "range",
67
+ "bounds": [float(lb), float(ub)],
68
+ "value_type": "float",
69
+ }
70
+ )
71
+
72
+ self._ax_client = AxClient(verbose_logging=False)
73
+ self._ax_client.create_experiment(
74
+ name="physics_informed_bo",
75
+ parameters=parameters,
76
+ objectives={
77
+ objective_name: ObjectiveProperties(minimize=minimize),
78
+ },
79
+ )
80
+
81
+ def suggest(
82
+ self,
83
+ n_candidates: int = 1,
84
+ X_observed: Optional[Tensor] = None,
85
+ y_observed: Optional[Tensor] = None,
86
+ ) -> Tensor:
87
+ """Suggest next experiments using AX."""
88
+ if not hasattr(self, "_ax_client"):
89
+ raise RuntimeError("Call setup_experiment() before suggesting.")
90
+
91
+ candidates = []
92
+ trial_indices = []
93
+
94
+ for _ in range(n_candidates):
95
+ parameters, trial_index = self._ax_client.get_next_trial()
96
+ candidates.append([parameters[name] for name in self._parameter_names])
97
+ trial_indices.append(trial_index)
98
+
99
+ self._last_trial_indices = trial_indices
100
+ result = torch.tensor(candidates, dtype=torch.float64)
101
+
102
+ # Filter through physics constraints
103
+ result = self._filter_feasible(result)
104
+ return result[:n_candidates]
105
+
106
+ def update(self, X_new: Tensor, y_new: Tensor) -> None:
107
+ """Report observations back to AX."""
108
+ if not hasattr(self, "_ax_client"):
109
+ return
110
+
111
+ for i, (x, y_val) in enumerate(zip(X_new, y_new)):
112
+ if i < len(self._last_trial_indices):
113
+ trial_index = self._last_trial_indices[i]
114
+ self._ax_client.complete_trial(
115
+ trial_index=trial_index,
116
+ raw_data={"objective": (float(y_val), 0.0)},
117
+ )
118
+
119
+ def get_best_parameters(self) -> Dict:
120
+ """Get the best parameters found so far."""
121
+ best_params, values = self._ax_client.get_best_parameters()
122
+ return {"parameters": best_params, "values": values}
123
+
124
+ def get_trials_dataframe(self):
125
+ """Get all trials as a pandas DataFrame."""
126
+ return self._ax_client.get_trials_data_frame()