Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
d398bf9
1
Parent(s):
c9f1ebd
Add PySRRegressor versions of jax/torch tests
Browse files- test/test_jax.py +4 -2
- test/test_torch.py +6 -3
test/test_jax.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import unittest
|
2 |
import numpy as np
|
3 |
-
from pysr import sympy2jax, get_hof
|
4 |
import pandas as pd
|
5 |
from jax import numpy as jnp
|
6 |
from jax import random
|
@@ -46,7 +46,9 @@ class TestJAX(unittest.TestCase):
|
|
46 |
selection=[1, 2, 3],
|
47 |
)
|
48 |
|
49 |
-
|
|
|
|
|
50 |
np.testing.assert_almost_equal(
|
51 |
np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
|
52 |
np.square(np.cos(X[:, 1])), # Select feature 1
|
|
|
1 |
import unittest
|
2 |
import numpy as np
|
3 |
+
from pysr import sympy2jax, get_hof, PySRRegressor
|
4 |
import pandas as pd
|
5 |
from jax import numpy as jnp
|
6 |
from jax import random
|
|
|
46 |
selection=[1, 2, 3],
|
47 |
)
|
48 |
|
49 |
+
model = PySRRegressor()
|
50 |
+
jformat = model.jax()
|
51 |
+
|
52 |
np.testing.assert_almost_equal(
|
53 |
np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
|
54 |
np.square(np.cos(X[:, 1])), # Select feature 1
|
test/test_torch.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import unittest
|
2 |
import numpy as np
|
3 |
import pandas as pd
|
4 |
-
from pysr import sympy2torch, get_hof
|
5 |
import torch
|
6 |
import sympy
|
7 |
|
@@ -84,7 +84,7 @@ class TestTorch(unittest.TestCase):
|
|
84 |
"equation_file_custom_operator.csv.bkup", sep="|"
|
85 |
)
|
86 |
|
87 |
-
|
88 |
"equation_file_custom_operator.csv",
|
89 |
n_features=3,
|
90 |
variables_names="x1 x2 x3".split(" "),
|
@@ -96,7 +96,10 @@ class TestTorch(unittest.TestCase):
|
|
96 |
selection=[0, 1, 2],
|
97 |
)
|
98 |
|
99 |
-
|
|
|
|
|
|
|
100 |
np.testing.assert_almost_equal(
|
101 |
tformat(torch.tensor(X)).detach().numpy(),
|
102 |
np.sin(X[:, 0]), # Selection 1st feature
|
|
|
1 |
import unittest
|
2 |
import numpy as np
|
3 |
import pandas as pd
|
4 |
+
from pysr import sympy2torch, get_hof, PySRRegressor
|
5 |
import torch
|
6 |
import sympy
|
7 |
|
|
|
84 |
"equation_file_custom_operator.csv.bkup", sep="|"
|
85 |
)
|
86 |
|
87 |
+
get_hof(
|
88 |
"equation_file_custom_operator.csv",
|
89 |
n_features=3,
|
90 |
variables_names="x1 x2 x3".split(" "),
|
|
|
96 |
selection=[0, 1, 2],
|
97 |
)
|
98 |
|
99 |
+
model = PySRRegressor()
|
100 |
+
# Will automatically use the set global state from get_hof.
|
101 |
+
tformat = model.pytorch()
|
102 |
+
|
103 |
np.testing.assert_almost_equal(
|
104 |
tformat(torch.tensor(X)).detach().numpy(),
|
105 |
np.sin(X[:, 0]), # Selection 1st feature
|