ravimohan19 commited on
Commit
0342a1d
·
verified ·
1 Parent(s): 28656f7

Upload optimizers/base_optimizer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. optimizers/base_optimizer.py +66 -0
optimizers/base_optimizer.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Base class for all optimizer backends."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Dict, List, Optional, Tuple
5
+
6
+ import torch
7
+ from torch import Tensor
8
+
9
+ from physics_informed_bo.config import OptimizationConfig
10
+ from physics_informed_bo.models.base import SurrogateModel
11
+ from physics_informed_bo.priors.physics_prior import PhysicsPrior
12
+
13
+
14
+ class BaseOptimizer(ABC):
15
+ """Abstract base class for optimizer backends (BoTorch, AX, BoFire)."""
16
+
17
+ def __init__(self, config: OptimizationConfig):
18
+ self.config = config
19
+ self._surrogate: Optional[SurrogateModel] = None
20
+ self._bounds: Optional[Tensor] = None
21
+ self._physics_prior: Optional[PhysicsPrior] = None
22
+
23
+ def set_surrogate(self, surrogate: SurrogateModel) -> None:
24
+ self._surrogate = surrogate
25
+
26
+ def set_bounds(self, bounds: Tensor) -> None:
27
+ """Set search space bounds. Shape: (2, d) where [0] = lower, [1] = upper."""
28
+ self._bounds = bounds
29
+
30
+ def set_physics_prior(self, physics_prior: PhysicsPrior) -> None:
31
+ self._physics_prior = physics_prior
32
+
33
+ @abstractmethod
34
+ def suggest(
35
+ self,
36
+ n_candidates: int = 1,
37
+ X_observed: Optional[Tensor] = None,
38
+ y_observed: Optional[Tensor] = None,
39
+ ) -> Tensor:
40
+ """Suggest next experiment(s) to run.
41
+
42
+ Args:
43
+ n_candidates: Number of candidates to suggest.
44
+ X_observed: All observed inputs so far.
45
+ y_observed: All observed outputs so far.
46
+
47
+ Returns:
48
+ Tensor of shape (n_candidates, d) with suggested experiments.
49
+ """
50
+
51
+ @abstractmethod
52
+ def update(self, X_new: Tensor, y_new: Tensor) -> None:
53
+ """Update the optimizer with new observations."""
54
+
55
+ def _filter_feasible(self, candidates: Tensor) -> Tensor:
56
+ """Filter candidates through physics constraints."""
57
+ if self._physics_prior is None:
58
+ return candidates
59
+ feasible_mask = self._physics_prior.check_feasibility(candidates)
60
+ feasible = candidates[feasible_mask]
61
+ if len(feasible) == 0:
62
+ # If no feasible candidates, return least-violating ones
63
+ violations = self._physics_prior.constraint_violation(candidates)
64
+ sorted_idx = violations.argsort()
65
+ return candidates[sorted_idx[:1]]
66
+ return feasible