MilesCranmer commited on
Commit
a117981
1 Parent(s): cbfdb9b

Test unit propagation

Browse files
Files changed (1) hide show
  1. pysr/test/test.py +28 -1
pysr/test/test.py CHANGED
@@ -983,9 +983,36 @@ class TestDimensionalConstraints(unittest.TestCase):
983
  y_units,
984
  )
985
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
986
 
987
  # TODO: add tests for:
988
- # - custom operators + dimensions
989
  # - no constants, so that it needs to find the right fraction
990
  # - custom dimensional_constraint_penalty
991
 
 
983
  y_units,
984
  )
985
 
986
+ def test_unit_propagation(self):
987
+ """Check that units are propagated correctly."""
988
+ X = np.ones((100, 3))
989
+ y = np.ones((100, 1))
990
+ model = PySRRegressor(
991
+ binary_operators=["+", "*"],
992
+ early_stop_condition="(l, c) -> l < 1e-8 && c == 3",
993
+ **self.default_test_kwargs,
994
+ complexity_of_constants=10,
995
+ weight_mutate_constant=0.0,
996
+ should_optimize_constants=False,
997
+ multithreading=False,
998
+ deterministic=True,
999
+ procs=0,
1000
+ random_state=0,
1001
+ )
1002
+ model.fit(
1003
+ X,
1004
+ y,
1005
+ X_units=["m", "s", "A"],
1006
+ y_units=["m*A"],
1007
+ )
1008
+ best = model.get_best()
1009
+ self.assertIn("x0", best["equation"])
1010
+ self.assertNotIn("x1", best["equation"])
1011
+ self.assertIn("x2", best["equation"])
1012
+ self.assertEqual(best["complexity"], 3)
1013
+
1014
 
1015
  # TODO: add tests for:
 
1016
  # - no constants, so that it needs to find the right fraction
1017
  # - custom dimensional_constraint_penalty
1018