MilesCranmer commited on
Commit
a6b7d35
1 Parent(s): 55e3b83

Fix selection index tests for jax/torch

Browse files
Files changed (2) hide show
  1. test/test_jax.py +3 -3
  2. test/test_torch.py +3 -3
test/test_jax.py CHANGED
@@ -19,7 +19,7 @@ class TestJAX(unittest.TestCase):
19
  def test_pipeline(self):
20
  X = np.random.randn(100, 10)
21
  equations = pd.DataFrame({
22
- 'Equation': ['1.0', 'cos(x1)', 'square(cos(x1))'],
23
  'MSE': [1.0, 0.1, 1e-5],
24
  'Complexity': [1, 2, 3]
25
  })
@@ -28,12 +28,12 @@ class TestJAX(unittest.TestCase):
28
  'equation_file.csv.bkup', sep='|')
29
 
30
  equations = get_hof(
31
- 'equation_file.csv', n_features=2, variables_names='x0 x1'.split(' '),
32
  extra_sympy_mappings={}, output_jax_format=True,
33
  multioutput=False, nout=1, selection=[1, 2, 3])
34
 
35
  jformat = equations.iloc[-1].jax_format
36
  np.testing.assert_almost_equal(
37
  np.array(jformat['callable'](jnp.array(X), jformat['parameters'])),
38
- np.square(np.cos(X[:, 0]))
39
  )
 
19
  def test_pipeline(self):
20
  X = np.random.randn(100, 10)
21
  equations = pd.DataFrame({
22
+ 'Equation': ['1.0', 'cos(x0)', 'square(cos(x0))'],
23
  'MSE': [1.0, 0.1, 1e-5],
24
  'Complexity': [1, 2, 3]
25
  })
 
28
  'equation_file.csv.bkup', sep='|')
29
 
30
  equations = get_hof(
31
+ 'equation_file.csv', n_features=2, variables_names='x1 x2 x3'.split(' '),
32
  extra_sympy_mappings={}, output_jax_format=True,
33
  multioutput=False, nout=1, selection=[1, 2, 3])
34
 
35
  jformat = equations.iloc[-1].jax_format
36
  np.testing.assert_almost_equal(
37
  np.array(jformat['callable'](jnp.array(X), jformat['parameters'])),
38
+ np.square(np.cos(X[:, 1])) # Select feature 1
39
  )
test/test_torch.py CHANGED
@@ -18,7 +18,7 @@ class TestTorch(unittest.TestCase):
18
  def test_pipeline(self):
19
  X = np.random.randn(100, 10)
20
  equations = pd.DataFrame({
21
- 'Equation': ['1.0', 'cos(x1)', 'square(cos(x1))'],
22
  'MSE': [1.0, 0.1, 1e-5],
23
  'Complexity': [1, 2, 3]
24
  })
@@ -27,12 +27,12 @@ class TestTorch(unittest.TestCase):
27
  'equation_file.csv.bkup', sep='|')
28
 
29
  equations = get_hof(
30
- 'equation_file.csv', n_features=2, variables_names='x0 x1'.split(' '),
31
  extra_sympy_mappings={}, output_torch_format=True,
32
  multioutput=False, nout=1, selection=[1, 2, 3])
33
 
34
  tformat = equations.iloc[-1].torch_format
35
  np.testing.assert_almost_equal(
36
  tformat(torch.tensor(X)).detach().numpy(),
37
- np.square(np.cos(X[:, 0]))
38
  )
 
18
  def test_pipeline(self):
19
  X = np.random.randn(100, 10)
20
  equations = pd.DataFrame({
21
+ 'Equation': ['1.0', 'cos(x0)', 'square(cos(x0))'],
22
  'MSE': [1.0, 0.1, 1e-5],
23
  'Complexity': [1, 2, 3]
24
  })
 
27
  'equation_file.csv.bkup', sep='|')
28
 
29
  equations = get_hof(
30
+ 'equation_file.csv', n_features=2, variables_names='x1 x2 x3'.split(' '),
31
  extra_sympy_mappings={}, output_torch_format=True,
32
  multioutput=False, nout=1, selection=[1, 2, 3])
33
 
34
  tformat = equations.iloc[-1].torch_format
35
  np.testing.assert_almost_equal(
36
  tformat(torch.tensor(X)).detach().numpy(),
37
+ np.square(np.cos(X[:, 1])) #Selection 1st feature
38
  )