MilesCranmer commited on
Commit
c86910d
·
unverified ·
1 Parent(s): 1d19a08

Add full objective test

Browse files
Files changed (1) hide show
  1. pysr/test/test.py +15 -3
pysr/test/test.py CHANGED
@@ -72,8 +72,10 @@ class TestPipeline(unittest.TestCase):
72
  print(model.equations_)
73
  self.assertLessEqual(model.get_best()["loss"], 1e-4)
74
 
75
- def test_multiprocessing_turbo(self):
 
76
  y = self.X[:, 0]
 
77
  model = PySRRegressor(
78
  **self.default_test_kwargs,
79
  # Turbo needs to work with unsafe operators:
@@ -81,11 +83,21 @@ class TestPipeline(unittest.TestCase):
81
  procs=2,
82
  multithreading=False,
83
  turbo=True,
84
- early_stop_condition="stop_if(loss, complexity) = loss < 1e-4 && complexity == 1",
 
 
 
 
 
 
 
 
85
  )
86
  model.fit(self.X, y)
87
  print(model.equations_)
88
- self.assertLessEqual(model.equations_.iloc[-1]["loss"], 1e-4)
 
 
89
 
90
  def test_high_precision_search_custom_loss(self):
91
  y = 1.23456789 * self.X[:, 0]
 
72
  print(model.equations_)
73
  self.assertLessEqual(model.get_best()["loss"], 1e-4)
74
 
75
+ def test_multiprocessing_turbo_custom_objective(self):
76
+ rstate = np.random.RandomState(0)
77
  y = self.X[:, 0]
78
+ y += rstate.randn(*y.shape) * 1e-4
79
  model = PySRRegressor(
80
  **self.default_test_kwargs,
81
  # Turbo needs to work with unsafe operators:
 
83
  procs=2,
84
  multithreading=False,
85
  turbo=True,
86
+ early_stop_condition="stop_if(loss, complexity) = loss < 1e-10 && complexity == 1",
87
+ full_objective="""
88
+ function my_objective(tree::Node{T}, dataset::Dataset{T}, options::Options) where T
89
+ prediction, flag = eval_tree_array(tree, dataset.X, options)
90
+ !flag && return T(Inf)
91
+ abs3(x) = abs(x) ^ 3
92
+ return sum(abs3, prediction .- dataset.y) / length(prediction)
93
+ end
94
+ """,
95
  )
96
  model.fit(self.X, y)
97
  print(model.equations_)
98
+ best_loss = model.equations_.iloc[-1]["loss"]
99
+ self.assertLessEqual(best_loss, 1e-10)
100
+ self.assertGreaterEqual(best_loss, 0.0)
101
 
102
  def test_high_precision_search_custom_loss(self):
103
  y = 1.23456789 * self.X[:, 0]