MilesCranmer commited on
Commit
085fe48
·
unverified ·
1 Parent(s): bbf18ee

test: improve coverage

Browse files
Files changed (1) hide show
  1. pysr/test/test.py +43 -19
pysr/test/test.py CHANGED
@@ -183,26 +183,38 @@ class TestPipeline(unittest.TestCase):
183
  self.assertLessEqual(mse2, 1e-4)
184
 
185
  def test_custom_variable_complexity(self):
186
- for case in (1, 2):
187
- y = self.X[:, [0, 1]]
188
- model = PySRRegressor(
189
- binary_operators=["+"],
190
- verbosity=0,
191
- **self.default_test_kwargs,
192
- early_stop_condition=f"stop_if_{case}(l, c) = l < 1e-8 && c <= {3 if case == 1 else 2}",
193
- )
194
- if case == 1:
195
- complexity_of_variables = [2, 3]
196
- elif case == 2:
197
- complexity_of_variables = 2
198
- model.fit(
199
- self.X[:, [0, 1]], y, complexity_of_variables=complexity_of_variables
200
- )
201
- self.assertLessEqual(model.get_best()[0]["loss"], 1e-8)
202
- self.assertLessEqual(model.get_best()[1]["loss"], 1e-8)
 
 
 
 
 
 
 
 
 
 
203
 
204
- self.assertEqual(model.get_best()[0]["complexity"], 2)
205
- self.assertEqual(model.get_best()[1]["complexity"], 3 if case == 1 else 2)
 
 
206
 
207
  def test_error_message_custom_variable_complexity(self):
208
  X = np.ones((10, 2))
@@ -215,6 +227,18 @@ class TestPipeline(unittest.TestCase):
215
  "number of elements in `complexity_of_variables`", str(cm.exception)
216
  )
217
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  def test_multioutput_weighted_with_callable_temp_equation(self):
219
  X = self.X.copy()
220
  y = X[:, [0, 1]] ** 2
 
183
  self.assertLessEqual(mse2, 1e-4)
184
 
185
  def test_custom_variable_complexity(self):
186
+ for outer in (True, False):
187
+ for case in (1, 2):
188
+ y = self.X[:, [0, 1]]
189
+ if case == 1:
190
+ kwargs = dict(complexity_of_variables=[2, 3])
191
+ elif case == 2:
192
+ kwargs = dict(complexity_of_variables=2)
193
+
194
+ if outer:
195
+ outer_kwargs = kwargs
196
+ inner_kwargs = dict()
197
+ else:
198
+ outer_kwargs = dict()
199
+ inner_kwargs = kwargs
200
+
201
+ model = PySRRegressor(
202
+ binary_operators=["+"],
203
+ verbosity=0,
204
+ **self.default_test_kwargs,
205
+ early_stop_condition=(
206
+ f"stop_if_{case}(l, c) = l < 1e-8 && c <= {3 if case == 1 else 2}"
207
+ ),
208
+ **outer_kwargs,
209
+ )
210
+ model.fit(self.X[:, [0, 1]], y, **inner_kwargs)
211
+ self.assertLessEqual(model.get_best()[0]["loss"], 1e-8)
212
+ self.assertLessEqual(model.get_best()[1]["loss"], 1e-8)
213
 
214
+ self.assertEqual(model.get_best()[0]["complexity"], 2)
215
+ self.assertEqual(
216
+ model.get_best()[1]["complexity"], 3 if case == 1 else 2
217
+ )
218
 
219
  def test_error_message_custom_variable_complexity(self):
220
  X = np.ones((10, 2))
 
227
  "number of elements in `complexity_of_variables`", str(cm.exception)
228
  )
229
 
230
+ def test_error_message_both_variable_complexity(self):
231
+ X = np.ones((10, 2))
232
+ y = np.ones((10,))
233
+ model = PySRRegressor(complexity_of_variables=[1, 2])
234
+ with self.assertRaises(ValueError) as cm:
235
+ model.fit(X, y, complexity_of_variables=[1, 2, 3])
236
+
237
+ self.assertIn(
238
+ "You cannot set `complexity_of_variables` at both `fit` and `__init__`.",
239
+ str(cm.exception),
240
+ )
241
+
242
  def test_multioutput_weighted_with_callable_temp_equation(self):
243
  X = self.X.copy()
244
  y = X[:, [0, 1]] ** 2