Tolerance
Browse files- surrogate.py +4 -3
surrogate.py
CHANGED
@@ -39,6 +39,7 @@ PARAM_BOUNDS = [
|
|
39 |
{"name": "train_frac", "type": "range", "bounds": [0.01, 1.0]},
|
40 |
]
|
41 |
|
|
|
42 |
|
43 |
class Parameterization(BaseModel):
|
44 |
N: float # int
|
@@ -77,7 +78,7 @@ class Parameterization(BaseModel):
|
|
77 |
|
78 |
if param["type"] == "range":
|
79 |
min_val, max_val = param["bounds"]
|
80 |
-
if not min_val <= v <= max_val:
|
81 |
raise ValueError(
|
82 |
f"{info.field_name} must be between {min_val} and {max_val}"
|
83 |
)
|
@@ -89,11 +90,11 @@ class Parameterization(BaseModel):
|
|
89 |
|
90 |
@model_validator(mode="after")
|
91 |
def check_constraints(self) -> "Parameterization":
|
92 |
-
if self.betas1 > self.betas2:
|
93 |
raise ValueError(
|
94 |
f"Received betas1={self.betas1} which should be less than betas2={self.betas2}"
|
95 |
)
|
96 |
-
if self.emb_scaler + self.pos_scaler > 1.0:
|
97 |
raise ValueError(
|
98 |
f"Received emb_scaler={self.emb_scaler} and pos_scaler={self.pos_scaler} which should sum to less than or equal to 1.0" # noqa: E501
|
99 |
)
|
|
|
39 |
{"name": "train_frac", "type": "range", "bounds": [0.01, 1.0]},
|
40 |
]
|
41 |
|
42 |
+
tol = 1e-6
|
43 |
|
44 |
class Parameterization(BaseModel):
|
45 |
N: float # int
|
|
|
78 |
|
79 |
if param["type"] == "range":
|
80 |
min_val, max_val = param["bounds"]
|
81 |
+
if not (min_val - tol) <= v <= (max_val + tol):
|
82 |
raise ValueError(
|
83 |
f"{info.field_name} must be between {min_val} and {max_val}"
|
84 |
)
|
|
|
90 |
|
91 |
@model_validator(mode="after")
|
92 |
def check_constraints(self) -> "Parameterization":
|
93 |
+
if (self.betas1 - tol) > (self.betas2 + tol):
|
94 |
raise ValueError(
|
95 |
f"Received betas1={self.betas1} which should be less than betas2={self.betas2}"
|
96 |
)
|
97 |
+
if self.emb_scaler + self.pos_scaler - tol > 1.0:
|
98 |
raise ValueError(
|
99 |
f"Received emb_scaler={self.emb_scaler} and pos_scaler={self.pos_scaler} which should sum to less than or equal to 1.0" # noqa: E501
|
100 |
)
|