sgbaird commited on
Commit
ddfbcf8
1 Parent(s): 5b742f8
Files changed (1) hide show
  1. surrogate.py +4 -3
surrogate.py CHANGED
@@ -39,6 +39,7 @@ PARAM_BOUNDS = [
39
  {"name": "train_frac", "type": "range", "bounds": [0.01, 1.0]},
40
  ]
41
 
 
42
 
43
  class Parameterization(BaseModel):
44
  N: float # int
@@ -77,7 +78,7 @@ class Parameterization(BaseModel):
77
 
78
  if param["type"] == "range":
79
  min_val, max_val = param["bounds"]
80
- if not min_val <= v <= max_val:
81
  raise ValueError(
82
  f"{info.field_name} must be between {min_val} and {max_val}"
83
  )
@@ -89,11 +90,11 @@ class Parameterization(BaseModel):
89
 
90
  @model_validator(mode="after")
91
  def check_constraints(self) -> "Parameterization":
92
- if self.betas1 > self.betas2:
93
  raise ValueError(
94
  f"Received betas1={self.betas1} which should be less than betas2={self.betas2}"
95
  )
96
- if self.emb_scaler + self.pos_scaler > 1.0:
97
  raise ValueError(
98
  f"Received emb_scaler={self.emb_scaler} and pos_scaler={self.pos_scaler} which should sum to less than or equal to 1.0" # noqa: E501
99
  )
 
39
  {"name": "train_frac", "type": "range", "bounds": [0.01, 1.0]},
40
  ]
41
 
42
+ tol = 1e-6
43
 
44
  class Parameterization(BaseModel):
45
  N: float # int
 
78
 
79
  if param["type"] == "range":
80
  min_val, max_val = param["bounds"]
81
+ if not (min_val - tol) <= v <= (max_val + tol):
82
  raise ValueError(
83
  f"{info.field_name} must be between {min_val} and {max_val}"
84
  )
 
90
 
91
  @model_validator(mode="after")
92
  def check_constraints(self) -> "Parameterization":
93
+ if (self.betas1 - tol) > (self.betas2 + tol):
94
  raise ValueError(
95
  f"Received betas1={self.betas1} which should be less than betas2={self.betas2}"
96
  )
97
+ if self.emb_scaler + self.pos_scaler - tol > 1.0:
98
  raise ValueError(
99
  f"Received emb_scaler={self.emb_scaler} and pos_scaler={self.pos_scaler} which should sum to less than or equal to 1.0" # noqa: E501
100
  )