Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
82b18ca
1
Parent(s):
44b5271
Include unit-test for multi-output equations
Browse files- test/test.py +53 -7
test/test.py
CHANGED
@@ -296,13 +296,24 @@ def manually_create_model(equations, feature_names=None):
|
|
296 |
)
|
297 |
|
298 |
# Set up internal parameters as if it had been fitted:
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
|
307 |
model.refresh()
|
308 |
|
@@ -581,6 +592,41 @@ class TestLaTeXTable(unittest.TestCase):
|
|
581 |
true_latex_table_str = self.create_true_latex(middle_part)
|
582 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
583 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
584 |
def test_latex_float_precision(self):
|
585 |
"""Test that we can print latex expressions with custom precision"""
|
586 |
expr = sympy.Float(4583.4485748, dps=50)
|
|
|
296 |
)
|
297 |
|
298 |
# Set up internal parameters as if it had been fitted:
|
299 |
+
if isinstance(equations, list):
|
300 |
+
# Multi-output.
|
301 |
+
model.equation_file_ = "equation_file.csv"
|
302 |
+
model.nout_ = len(equations)
|
303 |
+
model.selection_mask_ = None
|
304 |
+
model.feature_names_in_ = np.array(feature_names, dtype=object)
|
305 |
+
for i in range(model.nout_):
|
306 |
+
equations[i]["complexity loss equation".split(" ")].to_csv(
|
307 |
+
f"equation_file.csv.out{i+1}.bkup", sep="|"
|
308 |
+
)
|
309 |
+
else:
|
310 |
+
model.equation_file_ = "equation_file.csv"
|
311 |
+
model.nout_ = 1
|
312 |
+
model.selection_mask_ = None
|
313 |
+
model.feature_names_in_ = np.array(feature_names, dtype=object)
|
314 |
+
equations["complexity loss equation".split(" ")].to_csv(
|
315 |
+
"equation_file.csv.bkup", sep="|"
|
316 |
+
)
|
317 |
|
318 |
model.refresh()
|
319 |
|
|
|
592 |
true_latex_table_str = self.create_true_latex(middle_part)
|
593 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
594 |
|
595 |
+
def test_multi_output(self):
|
596 |
+
equations1 = pd.DataFrame(
|
597 |
+
dict(
|
598 |
+
equation=["x0", "cos(x0)", "x0 + x1 - cos(x1 * x0)"],
|
599 |
+
loss=[1.052, 0.02315, 1.12347e-15],
|
600 |
+
complexity=[1, 2, 8],
|
601 |
+
)
|
602 |
+
)
|
603 |
+
equations2 = pd.DataFrame(
|
604 |
+
dict(
|
605 |
+
equation=["x1", "cos(x1)", "x0 * x0 * x1"],
|
606 |
+
loss=[1.32, 0.052, 2e-15],
|
607 |
+
complexity=[1, 2, 5],
|
608 |
+
)
|
609 |
+
)
|
610 |
+
equations = [equations1, equations2]
|
611 |
+
model = manually_create_model(equations)
|
612 |
+
middle_part_1 = r"""
|
613 |
+
$x_{0}$ & $1$ & $1.05$ & $0.0$ \\
|
614 |
+
$\cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
|
615 |
+
$x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
|
616 |
+
"""
|
617 |
+
middle_part_2 = r"""
|
618 |
+
$x_{1}$ & $1$ & $1.32$ & $0.0$ \\
|
619 |
+
$\cos{\left(x_{1} \right)}$ & $2$ & $0.0520$ & $3.23$ \\
|
620 |
+
$x_{0}^{2} x_{1}$ & $5$ & $2.00 \cdot 10^{-15}$ & $10.3$ \\
|
621 |
+
"""
|
622 |
+
true_latex_table_str = "\n\n".join(
|
623 |
+
self.create_true_latex(part, include_score=True)
|
624 |
+
for part in [middle_part_1, middle_part_2]
|
625 |
+
)
|
626 |
+
latex_table_str = model.latex_table()
|
627 |
+
|
628 |
+
self.assertEqual(latex_table_str, true_latex_table_str)
|
629 |
+
|
630 |
def test_latex_float_precision(self):
|
631 |
"""Test that we can print latex expressions with custom precision"""
|
632 |
expr = sympy.Float(4583.4485748, dps=50)
|