Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
5e0dd71
1
Parent(s):
0a3d3e9
Fix up other arguments in test
Browse files- test/test_jax.py +1 -3
- test/test_torch.py +4 -8
test/test_jax.py
CHANGED
@@ -38,9 +38,7 @@ class TestJAX(unittest.TestCase):
|
|
38 |
model = PySRRegressor(
|
39 |
equation_file="equation_file.csv",
|
40 |
output_jax_format=True,
|
41 |
-
|
42 |
-
multioutput=False,
|
43 |
-
nout=1,
|
44 |
selection=[1, 2, 3],
|
45 |
)
|
46 |
|
|
|
38 |
model = PySRRegressor(
|
39 |
equation_file="equation_file.csv",
|
40 |
output_jax_format=True,
|
41 |
+
variable_names="x1 x2 x3".split(" "),
|
|
|
|
|
42 |
selection=[1, 2, 3],
|
43 |
)
|
44 |
|
test/test_torch.py
CHANGED
@@ -37,13 +37,11 @@ class TestTorch(unittest.TestCase):
|
|
37 |
model = PySRRegressor(
|
38 |
model_selection="accuracy",
|
39 |
equation_file="equation_file.csv",
|
40 |
-
|
41 |
extra_sympy_mappings={},
|
42 |
output_torch_format=True,
|
43 |
-
multioutput=False,
|
44 |
-
nout=1,
|
45 |
-
selection=[1, 2, 3],
|
46 |
)
|
|
|
47 |
model.n_features = 2 # TODO: Why is this 2 and not 3?
|
48 |
model.using_pandas = False
|
49 |
model.refresh()
|
@@ -91,14 +89,12 @@ class TestTorch(unittest.TestCase):
|
|
91 |
model = PySRRegressor(
|
92 |
model_selection="accuracy",
|
93 |
equation_file="equation_file_custom_operator.csv",
|
94 |
-
|
95 |
extra_sympy_mappings={"mycustomoperator": sympy.sin},
|
96 |
extra_torch_mappings={"mycustomoperator": torch.sin},
|
97 |
output_torch_format=True,
|
98 |
-
multioutput=False,
|
99 |
-
nout=1,
|
100 |
-
selection=[0, 1, 2],
|
101 |
)
|
|
|
102 |
model.n_features = 3
|
103 |
model.using_pandas = False
|
104 |
model.refresh()
|
|
|
37 |
model = PySRRegressor(
|
38 |
model_selection="accuracy",
|
39 |
equation_file="equation_file.csv",
|
40 |
+
variable_names="x1 x2 x3".split(" "),
|
41 |
extra_sympy_mappings={},
|
42 |
output_torch_format=True,
|
|
|
|
|
|
|
43 |
)
|
44 |
+
model.selection = [1, 2, 3]
|
45 |
model.n_features = 2 # TODO: Why is this 2 and not 3?
|
46 |
model.using_pandas = False
|
47 |
model.refresh()
|
|
|
89 |
model = PySRRegressor(
|
90 |
model_selection="accuracy",
|
91 |
equation_file="equation_file_custom_operator.csv",
|
92 |
+
variable_names="x1 x2 x3".split(" "),
|
93 |
extra_sympy_mappings={"mycustomoperator": sympy.sin},
|
94 |
extra_torch_mappings={"mycustomoperator": torch.sin},
|
95 |
output_torch_format=True,
|
|
|
|
|
|
|
96 |
)
|
97 |
+
model.selection = [0, 1, 2]
|
98 |
model.n_features = 3
|
99 |
model.using_pandas = False
|
100 |
model.refresh()
|