MilesCranmer commited on
Commit
b846feb
·
1 Parent(s): 4ee8cdb

Refactor use of backend as more module-like

Browse files
Files changed (2) hide show
  1. pysr/julia_helpers.py +14 -13
  2. pysr/sr.py +9 -10
pysr/julia_helpers.py CHANGED
@@ -94,19 +94,15 @@ def install(julia_project=None, quiet=False): # pragma: no cover
94
  )
95
 
96
 
97
- def _import_error_string(julia_project=None):
98
- s = """
 
99
  Required dependencies are not installed or built. Run the following code in the Python REPL:
100
 
101
  >>> import pysr
102
  >>> pysr.install()
103
  """
104
-
105
- if julia_project is not None:
106
- s += f"""
107
- Tried to activate project {julia_project} but failed."""
108
-
109
- return s
110
 
111
 
112
  def _process_julia_project(julia_project):
@@ -171,7 +167,7 @@ def init_julia(julia_project=None, quiet=False):
171
  )
172
 
173
  if not info.is_pycall_built():
174
- raise ImportError(_import_error_string())
175
 
176
  Main = None
177
  try:
@@ -259,19 +255,24 @@ def _load_cluster_manager(Main, cluster_manager):
259
  return Main.eval(f"addprocs_{cluster_manager}")
260
 
261
 
262
- def _update_julia_project(Main, julia_project, is_shared, io_arg):
263
  try:
264
  if is_shared:
265
  _add_sr_to_julia_project(Main, io_arg)
266
  Main.eval(f"Pkg.resolve({io_arg})")
267
  except (JuliaError, RuntimeError) as e:
268
- raise ImportError(_import_error_string(julia_project)) from e
269
 
270
 
271
- def _load_backend(Main, julia_project):
272
  try:
 
273
  Main.eval("using SymbolicRegression")
274
  except (JuliaError, RuntimeError) as e:
275
- raise ImportError(_import_error_string(julia_project)) from e
276
 
277
  _backend_version_assertion(Main)
 
 
 
 
 
94
  )
95
 
96
 
97
+ def _import_error():
98
+ raise ImportError(
99
+ """
100
  Required dependencies are not installed or built. Run the following code in the Python REPL:
101
 
102
  >>> import pysr
103
  >>> pysr.install()
104
  """
105
+ )
 
 
 
 
 
106
 
107
 
108
  def _process_julia_project(julia_project):
 
167
  )
168
 
169
  if not info.is_pycall_built():
170
+ _import_error()
171
 
172
  Main = None
173
  try:
 
255
  return Main.eval(f"addprocs_{cluster_manager}")
256
 
257
 
258
+ def _update_julia_project(Main, is_shared, io_arg):
259
  try:
260
  if is_shared:
261
  _add_sr_to_julia_project(Main, io_arg)
262
  Main.eval(f"Pkg.resolve({io_arg})")
263
  except (JuliaError, RuntimeError) as e:
264
+ raise ImportError(_import_error()) from e
265
 
266
 
267
+ def _load_backend(Main):
268
  try:
269
+ # Load namespace, so that various internal operators work:
270
  Main.eval("using SymbolicRegression")
271
  except (JuliaError, RuntimeError) as e:
272
+ raise ImportError(_import_error()) from e
273
 
274
  _backend_version_assertion(Main)
275
+
276
+ from julia import SymbolicRegression
277
+
278
+ return SymbolicRegression
pysr/sr.py CHANGED
@@ -1467,18 +1467,17 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1467
  Main.eval(
1468
  f'Pkg.activate("{_escape_filename(julia_project)}", shared = Bool({int(is_shared)}), {io_arg})'
1469
  )
1470
- from julia.api import JuliaError
1471
 
1472
  if self.update:
1473
- _update_julia_project(Main, julia_project, is_shared, io_arg)
1474
 
1475
- _load_backend(Main, julia_project)
1476
 
1477
- Main.plus = Main.eval("(+)")
1478
- Main.sub = Main.eval("(-)")
1479
- Main.mult = Main.eval("(*)")
1480
- Main.pow = Main.eval("(^)")
1481
- Main.div = Main.eval("(/)")
1482
 
1483
  # TODO(mcranmer): These functions should be part of this class.
1484
  binary_operators, unary_operators = _maybe_create_inline_operators(
@@ -1535,7 +1534,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1535
 
1536
  # Call to Julia backend.
1537
  # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
1538
- options = Main.Options(
1539
  binary_operators=Main.eval(str(tuple(binary_operators)).replace("'", "")),
1540
  unary_operators=Main.eval(str(tuple(unary_operators)).replace("'", "")),
1541
  bin_constraints=bin_constraints,
@@ -1608,7 +1607,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1608
 
1609
  # Call to Julia backend.
1610
  # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/SymbolicRegression.jl
1611
- self.raw_julia_state_ = Main.EquationSearch(
1612
  Main.X,
1613
  Main.y,
1614
  weights=Main.weights,
 
1467
  Main.eval(
1468
  f'Pkg.activate("{_escape_filename(julia_project)}", shared = Bool({int(is_shared)}), {io_arg})'
1469
  )
 
1470
 
1471
  if self.update:
1472
+ _update_julia_project(Main, is_shared, io_arg)
1473
 
1474
+ SymbolicRegression = _load_backend(Main)
1475
 
1476
+ Main.plus = Main.eval("(+)")
1477
+ Main.sub = Main.eval("(-)")
1478
+ Main.mult = Main.eval("(*)")
1479
+ Main.pow = Main.eval("(^)")
1480
+ Main.div = Main.eval("(/)")
1481
 
1482
  # TODO(mcranmer): These functions should be part of this class.
1483
  binary_operators, unary_operators = _maybe_create_inline_operators(
 
1534
 
1535
  # Call to Julia backend.
1536
  # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
1537
+ options = SymbolicRegression.Options(
1538
  binary_operators=Main.eval(str(tuple(binary_operators)).replace("'", "")),
1539
  unary_operators=Main.eval(str(tuple(unary_operators)).replace("'", "")),
1540
  bin_constraints=bin_constraints,
 
1607
 
1608
  # Call to Julia backend.
1609
  # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/SymbolicRegression.jl
1610
+ self.raw_julia_state_ = SymbolicRegression.EquationSearch(
1611
  Main.X,
1612
  Main.y,
1613
  weights=Main.weights,