Spaces:
Running
Running
MilesCranmer
commited on
Try to fix nb sanitizer
Browse files- pysr/_cli/main.py +15 -6
- pysr/test/test.py +6 -5
- pysr/test/test_cli.py +6 -2
- pysr/test/test_dev.py +6 -2
- pysr/test/test_jax.py +6 -2
- pysr/test/test_startup.py +6 -2
- pysr/test/test_torch.py +6 -2
pysr/_cli/main.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import warnings
|
2 |
|
3 |
import click
|
@@ -55,19 +56,27 @@ def _tests(tests):
|
|
55 |
|
56 |
Choose from main, jax, torch, cli, dev, and startup. You can give multiple tests, separated by commas.
|
57 |
"""
|
|
|
58 |
for test in tests.split(","):
|
59 |
if test == "main":
|
60 |
-
runtests()
|
61 |
elif test == "jax":
|
62 |
-
runtests_jax()
|
63 |
elif test == "torch":
|
64 |
-
runtests_torch()
|
65 |
elif test == "cli":
|
66 |
runtests_cli = get_runtests_cli()
|
67 |
-
runtests_cli()
|
68 |
elif test == "dev":
|
69 |
-
runtests_dev()
|
70 |
elif test == "startup":
|
71 |
-
runtests_startup()
|
72 |
else:
|
73 |
warnings.warn(f"Invalid test {test}. Skipping.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
import warnings
|
3 |
|
4 |
import click
|
|
|
56 |
|
57 |
Choose from main, jax, torch, cli, dev, and startup. You can give multiple tests, separated by commas.
|
58 |
"""
|
59 |
+
test_cases = []
|
60 |
for test in tests.split(","):
|
61 |
if test == "main":
|
62 |
+
test_cases.extend(runtests(just_tests=True))
|
63 |
elif test == "jax":
|
64 |
+
test_cases.extend(runtests_jax(just_tests=True))
|
65 |
elif test == "torch":
|
66 |
+
test_cases.extend(runtests_torch(just_tests=True))
|
67 |
elif test == "cli":
|
68 |
runtests_cli = get_runtests_cli()
|
69 |
+
test_cases.extend(runtests_cli(just_tests=True))
|
70 |
elif test == "dev":
|
71 |
+
test_cases.extend(runtests_dev(just_tests=True))
|
72 |
elif test == "startup":
|
73 |
+
test_cases.extend(runtests_startup(just_tests=True))
|
74 |
else:
|
75 |
warnings.warn(f"Invalid test {test}. Skipping.")
|
76 |
+
|
77 |
+
loader = unittest.TestLoader()
|
78 |
+
suite = unittest.TestSuite()
|
79 |
+
for test_case in test_cases:
|
80 |
+
suite.addTests(loader.loadTestsFromTestCase(test_case))
|
81 |
+
runner = unittest.TextTestRunner()
|
82 |
+
return runner.run(suite)
|
pysr/test/test.py
CHANGED
@@ -1127,10 +1127,8 @@ class TestDimensionalConstraints(unittest.TestCase):
|
|
1127 |
# TODO: Determine desired behavior if second .fit() call does not have units
|
1128 |
|
1129 |
|
1130 |
-
def runtests():
|
1131 |
"""Run all tests in test.py."""
|
1132 |
-
suite = unittest.TestSuite()
|
1133 |
-
loader = unittest.TestLoader()
|
1134 |
test_cases = [
|
1135 |
TestPipeline,
|
1136 |
TestBest,
|
@@ -1139,8 +1137,11 @@ def runtests():
|
|
1139 |
TestLaTeXTable,
|
1140 |
TestDimensionalConstraints,
|
1141 |
]
|
|
|
|
|
|
|
|
|
1142 |
for test_case in test_cases:
|
1143 |
-
|
1144 |
-
suite.addTests(tests)
|
1145 |
runner = unittest.TextTestRunner()
|
1146 |
return runner.run(suite)
|
|
|
1127 |
# TODO: Determine desired behavior if second .fit() call does not have units
|
1128 |
|
1129 |
|
1130 |
+
def runtests(just_tests=False):
|
1131 |
"""Run all tests in test.py."""
|
|
|
|
|
1132 |
test_cases = [
|
1133 |
TestPipeline,
|
1134 |
TestBest,
|
|
|
1137 |
TestLaTeXTable,
|
1138 |
TestDimensionalConstraints,
|
1139 |
]
|
1140 |
+
if just_tests:
|
1141 |
+
return test_cases
|
1142 |
+
suite = unittest.TestSuite()
|
1143 |
+
loader = unittest.TestLoader()
|
1144 |
for test_case in test_cases:
|
1145 |
+
suite.addTests(loader.loadTestsFromTestCase(test_case))
|
|
|
1146 |
runner = unittest.TextTestRunner()
|
1147 |
return runner.run(suite)
|
pysr/test/test_cli.py
CHANGED
@@ -68,11 +68,15 @@ def get_runtests():
|
|
68 |
self.assertEqual(result.output.strip(), expected.strip())
|
69 |
self.assertEqual(result.exit_code, 0)
|
70 |
|
71 |
-
def runtests():
|
72 |
"""Run all tests in cliTest.py."""
|
|
|
|
|
|
|
73 |
loader = unittest.TestLoader()
|
74 |
suite = unittest.TestSuite()
|
75 |
-
|
|
|
76 |
runner = unittest.TextTestRunner()
|
77 |
return runner.run(suite)
|
78 |
|
|
|
68 |
self.assertEqual(result.output.strip(), expected.strip())
|
69 |
self.assertEqual(result.exit_code, 0)
|
70 |
|
71 |
+
def runtests(just_tests=False):
|
72 |
"""Run all tests in cliTest.py."""
|
73 |
+
tests = [TestCli]
|
74 |
+
if just_tests:
|
75 |
+
return tests
|
76 |
loader = unittest.TestLoader()
|
77 |
suite = unittest.TestSuite()
|
78 |
+
for test in tests:
|
79 |
+
suite.addTests(loader.loadTestsFromTestCase(test))
|
80 |
runner = unittest.TextTestRunner()
|
81 |
return runner.run(suite)
|
82 |
|
pysr/test/test_dev.py
CHANGED
@@ -47,9 +47,13 @@ class TestDev(unittest.TestCase):
|
|
47 |
self.assertEqual(test_result.stdout.decode("utf-8").strip(), "2.3")
|
48 |
|
49 |
|
50 |
-
def runtests():
|
|
|
|
|
|
|
51 |
suite = unittest.TestSuite()
|
52 |
loader = unittest.TestLoader()
|
53 |
-
|
|
|
54 |
runner = unittest.TextTestRunner()
|
55 |
return runner.run(suite)
|
|
|
47 |
self.assertEqual(test_result.stdout.decode("utf-8").strip(), "2.3")
|
48 |
|
49 |
|
50 |
+
def runtests(just_tests=False):
|
51 |
+
tests = [TestDev]
|
52 |
+
if just_tests:
|
53 |
+
return tests
|
54 |
suite = unittest.TestSuite()
|
55 |
loader = unittest.TestLoader()
|
56 |
+
for test in tests:
|
57 |
+
suite.addTests(loader.loadTestsFromTestCase(test))
|
58 |
runner = unittest.TextTestRunner()
|
59 |
return runner.run(suite)
|
pysr/test/test_jax.py
CHANGED
@@ -121,10 +121,14 @@ class TestJAX(unittest.TestCase):
|
|
121 |
np.testing.assert_almost_equal(y.values, jax_output, decimal=3)
|
122 |
|
123 |
|
124 |
-
def runtests():
|
125 |
"""Run all tests in test_jax.py."""
|
|
|
|
|
|
|
126 |
loader = unittest.TestLoader()
|
127 |
suite = unittest.TestSuite()
|
128 |
-
|
|
|
129 |
runner = unittest.TextTestRunner()
|
130 |
return runner.run(suite)
|
|
|
121 |
np.testing.assert_almost_equal(y.values, jax_output, decimal=3)
|
122 |
|
123 |
|
124 |
+
def runtests(just_tests=False):
|
125 |
"""Run all tests in test_jax.py."""
|
126 |
+
tests = [TestJAX]
|
127 |
+
if just_tests:
|
128 |
+
return tests
|
129 |
loader = unittest.TestLoader()
|
130 |
suite = unittest.TestSuite()
|
131 |
+
for test in tests:
|
132 |
+
suite.addTests(loader.loadTestsFromTestCase(test))
|
133 |
runner = unittest.TextTestRunner()
|
134 |
return runner.run(suite)
|
pysr/test/test_startup.py
CHANGED
@@ -143,9 +143,13 @@ class TestStartup(unittest.TestCase):
|
|
143 |
self.assertEqual(result.returncode, 0)
|
144 |
|
145 |
|
146 |
-
def runtests():
|
|
|
|
|
|
|
147 |
suite = unittest.TestSuite()
|
148 |
loader = unittest.TestLoader()
|
149 |
-
|
|
|
150 |
runner = unittest.TextTestRunner()
|
151 |
return runner.run(suite)
|
|
|
143 |
self.assertEqual(result.returncode, 0)
|
144 |
|
145 |
|
146 |
+
def runtests(just_tests=False):
|
147 |
+
tests = [TestStartup]
|
148 |
+
if just_tests:
|
149 |
+
return tests
|
150 |
suite = unittest.TestSuite()
|
151 |
loader = unittest.TestLoader()
|
152 |
+
for test in tests:
|
153 |
+
suite.addTests(loader.loadTestsFromTestCase(test))
|
154 |
runner = unittest.TextTestRunner()
|
155 |
return runner.run(suite)
|
pysr/test/test_torch.py
CHANGED
@@ -184,10 +184,14 @@ class TestTorch(unittest.TestCase):
|
|
184 |
np.testing.assert_almost_equal(y.values, torch_output, decimal=3)
|
185 |
|
186 |
|
187 |
-
def runtests():
|
188 |
"""Run all tests in test_torch.py."""
|
|
|
|
|
|
|
189 |
loader = unittest.TestLoader()
|
190 |
suite = unittest.TestSuite()
|
191 |
-
|
|
|
192 |
runner = unittest.TextTestRunner()
|
193 |
return runner.run(suite)
|
|
|
184 |
np.testing.assert_almost_equal(y.values, torch_output, decimal=3)
|
185 |
|
186 |
|
187 |
+
def runtests(just_tests=False):
|
188 |
"""Run all tests in test_torch.py."""
|
189 |
+
tests = [TestTorch]
|
190 |
+
if just_tests:
|
191 |
+
return tests
|
192 |
loader = unittest.TestLoader()
|
193 |
suite = unittest.TestSuite()
|
194 |
+
for test in tests:
|
195 |
+
suite.addTests(loader.loadTestsFromTestCase(test))
|
196 |
runner = unittest.TextTestRunner()
|
197 |
return runner.run(suite)
|