Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
d4d95e5
1
Parent(s):
7d19ebb
Add test for custom torch operator
Browse files- test/test_torch.py +35 -1
test/test_torch.py
CHANGED
@@ -36,7 +36,7 @@ class TestTorch(unittest.TestCase):
|
|
36 |
|
37 |
equations = get_hof(
|
38 |
"equation_file.csv",
|
39 |
-
n_features=2,
|
40 |
variables_names="x1 x2 x3".split(" "),
|
41 |
extra_sympy_mappings={},
|
42 |
output_torch_format=True,
|
@@ -68,3 +68,37 @@ class TestTorch(unittest.TestCase):
|
|
68 |
np.testing.assert_array_almost_equal(
|
69 |
true_out.detach(), torch_out.detach(), decimal=4
|
70 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
equations = get_hof(
|
38 |
"equation_file.csv",
|
39 |
+
n_features=2, # TODO: Why is this 2 and not 3?
|
40 |
variables_names="x1 x2 x3".split(" "),
|
41 |
extra_sympy_mappings={},
|
42 |
output_torch_format=True,
|
|
|
68 |
np.testing.assert_array_almost_equal(
|
69 |
true_out.detach(), torch_out.detach(), decimal=4
|
70 |
)
|
71 |
+
|
72 |
+
def test_custom_operator(self):
|
73 |
+
X = np.random.randn(100, 3)
|
74 |
+
|
75 |
+
equations = pd.DataFrame(
|
76 |
+
{
|
77 |
+
"Equation": ["1.0", "mycustomoperator(x0)"],
|
78 |
+
"MSE": [1.0, 0.1],
|
79 |
+
"Complexity": [1, 2],
|
80 |
+
}
|
81 |
+
)
|
82 |
+
|
83 |
+
equations["Complexity MSE Equation".split(" ")].to_csv(
|
84 |
+
"equation_file_custom_operator.csv.bkup", sep="|"
|
85 |
+
)
|
86 |
+
|
87 |
+
equations = get_hof(
|
88 |
+
"equation_file_custom_operator.csv",
|
89 |
+
n_features=3,
|
90 |
+
variables_names="x1 x2 x3".split(" "),
|
91 |
+
extra_sympy_mappings={"mycustomoperator": sympy.sin},
|
92 |
+
extra_torch_mappings={"mycustomoperator": torch.sin},
|
93 |
+
output_torch_format=True,
|
94 |
+
multioutput=False,
|
95 |
+
nout=1,
|
96 |
+
selection=[0, 1, 2],
|
97 |
+
)
|
98 |
+
|
99 |
+
tformat = equations.iloc[-1].torch_format
|
100 |
+
np.testing.assert_almost_equal(
|
101 |
+
tformat(torch.tensor(X)).detach().numpy(),
|
102 |
+
np.sin(X[:, 0]), # Selection 1st feature
|
103 |
+
decimal=4,
|
104 |
+
)
|