Spaces:
Running
Running
Commit
·
05cf610
1
Parent(s):
6add8e3
Test sympy export to jax
Browse files- .github/workflows/CI.yml +1 -0
- test/test.py +14 -2
.github/workflows/CI.yml
CHANGED
|
@@ -58,6 +58,7 @@ jobs:
|
|
| 58 |
run: |
|
| 59 |
python -m pip install --upgrade pip
|
| 60 |
pip install -r requirements.txt
|
|
|
|
| 61 |
python setup.py install
|
| 62 |
shell: bash
|
| 63 |
- name: "Run tests"
|
|
|
|
| 58 |
run: |
|
| 59 |
python -m pip install --upgrade pip
|
| 60 |
pip install -r requirements.txt
|
| 61 |
+
pip install jax jaxlib # (optional import)
|
| 62 |
python setup.py install
|
| 63 |
shell: bash
|
| 64 |
- name: "Run tests"
|
test/test.py
CHANGED
|
@@ -1,5 +1,9 @@
|
|
| 1 |
import numpy as np
|
| 2 |
-
from pysr import pysr
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
X = np.random.randn(100, 5)
|
| 4 |
|
| 5 |
print("Test 1 - defaults; simple linear relation")
|
|
@@ -27,6 +31,14 @@ equations = pysr(X, y,
|
|
| 27 |
unary_operators=[], binary_operators=["plus"],
|
| 28 |
niterations=10,
|
| 29 |
user_input=False)
|
| 30 |
-
|
| 31 |
print(equations)
|
| 32 |
assert equations.iloc[-1]['MSE'] < 1e-4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
+
from pysr import pysr, sympy2jax
|
| 3 |
+
from jax import numpy as jnp
|
| 4 |
+
from jax import random
|
| 5 |
+
from jax import grad
|
| 6 |
+
import sympy
|
| 7 |
X = np.random.randn(100, 5)
|
| 8 |
|
| 9 |
print("Test 1 - defaults; simple linear relation")
|
|
|
|
| 31 |
unary_operators=[], binary_operators=["plus"],
|
| 32 |
niterations=10,
|
| 33 |
user_input=False)
|
|
|
|
| 34 |
print(equations)
|
| 35 |
assert equations.iloc[-1]['MSE'] < 1e-4
|
| 36 |
+
|
| 37 |
+
print("Test 4 - text JAX export")
|
| 38 |
+
x, y, z = sympy.symbols('x y z')
|
| 39 |
+
cosx = 1.0 * sympy.cos(x) + y
|
| 40 |
+
key = random.PRNGKey(0)
|
| 41 |
+
X = random.normal(key, (1000, 2))
|
| 42 |
+
true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
|
| 43 |
+
f, params = sympy2jax(cosx, [x])
|
| 44 |
+
assert jnp.all(jnp.isclose(f(X, params), true)).item()
|