MilesCranmer commited on
Commit
52a6b5b
2 Parent(s): d09ade8 b14e38a

Merge pull request #186 from mkitti/mkitti/set_julia_project

Browse files
Files changed (2) hide show
  1. pysr/julia_helpers.py +21 -5
  2. pysr/sr.py +1 -1
pysr/julia_helpers.py CHANGED
@@ -12,13 +12,24 @@ def install(julia_project=None, quiet=False): # pragma: no cover
12
 
13
  Also updates the local Julia registry.
14
  """
 
 
 
 
 
 
 
15
  import julia
16
 
17
  julia.install(quiet=quiet)
18
 
19
- julia_project, is_shared = _get_julia_project(julia_project)
 
 
 
 
 
20
 
21
- Main = init_julia()
22
  Main.eval("using Pkg")
23
 
24
  io = "devnull" if quiet else "stderr"
@@ -72,10 +83,16 @@ def is_julia_version_greater_eq(Main, version="1.6"):
72
  return Main.eval(f'VERSION >= v"{version}"')
73
 
74
 
75
- def init_julia():
76
  """Initialize julia binary, turning off compiled modules if needed."""
77
  from julia.core import JuliaInfo, UnsupportedPythonError
78
 
 
 
 
 
 
 
79
  try:
80
  info = JuliaInfo.load(julia="julia")
81
  except FileNotFoundError:
@@ -110,13 +127,12 @@ def _add_sr_to_julia_project(Main, io_arg):
110
  url="https://github.com/MilesCranmer/SymbolicRegression.jl",
111
  rev="v" + __symbolic_regression_jl_version__,
112
  )
113
- Main.eval(f"Pkg.add(sr_spec, {io_arg})")
114
  Main.clustermanagers_spec = Main.PackageSpec(
115
  name="ClusterManagers",
116
  url="https://github.com/JuliaParallel/ClusterManagers.jl",
117
  rev="14e7302f068794099344d5d93f71979aaf4fbeb3",
118
  )
119
- Main.eval(f"Pkg.add(clustermanagers_spec, {io_arg})")
120
 
121
 
122
  def _escape_filename(filename):
 
12
 
13
  Also updates the local Julia registry.
14
  """
15
+ # Set JULIA_PROJECT so that we install in the pysr environment
16
+ julia_project, is_shared = _get_julia_project(julia_project)
17
+ if is_shared:
18
+ os.environ["JULIA_PROJECT"] = "@" + str(julia_project)
19
+ else:
20
+ os.environ["JULIA_PROJECT"] = str(julia_project)
21
+
22
  import julia
23
 
24
  julia.install(quiet=quiet)
25
 
26
+ if is_shared:
27
+ # is_shared is only true if the julia_project arg was None
28
+ # See _get_julia_project
29
+ Main = init_julia(None)
30
+ else:
31
+ Main = init_julia(julia_project)
32
 
 
33
  Main.eval("using Pkg")
34
 
35
  io = "devnull" if quiet else "stderr"
 
83
  return Main.eval(f'VERSION >= v"{version}"')
84
 
85
 
86
+ def init_julia(julia_project=None):
87
  """Initialize julia binary, turning off compiled modules if needed."""
88
  from julia.core import JuliaInfo, UnsupportedPythonError
89
 
90
+ julia_project, is_shared = _get_julia_project(julia_project)
91
+ if is_shared:
92
+ os.environ["JULIA_PROJECT"] = "@" + str(julia_project)
93
+ else:
94
+ os.environ["JULIA_PROJECT"] = str(julia_project)
95
+
96
  try:
97
  info = JuliaInfo.load(julia="julia")
98
  except FileNotFoundError:
 
127
  url="https://github.com/MilesCranmer/SymbolicRegression.jl",
128
  rev="v" + __symbolic_regression_jl_version__,
129
  )
 
130
  Main.clustermanagers_spec = Main.PackageSpec(
131
  name="ClusterManagers",
132
  url="https://github.com/JuliaParallel/ClusterManagers.jl",
133
  rev="14e7302f068794099344d5d93f71979aaf4fbeb3",
134
  )
135
+ Main.eval(f"Pkg.add([sr_spec, clustermanagers_spec], {io_arg})")
136
 
137
 
138
  def _escape_filename(filename):
pysr/sr.py CHANGED
@@ -1430,7 +1430,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1430
  if multithreading:
1431
  os.environ["JULIA_NUM_THREADS"] = str(self.procs)
1432
 
1433
- Main = init_julia()
1434
 
1435
  if cluster_manager is not None:
1436
  Main.eval(f"import ClusterManagers: addprocs_{cluster_manager}")
 
1430
  if multithreading:
1431
  os.environ["JULIA_NUM_THREADS"] = str(self.procs)
1432
 
1433
+ Main = init_julia(self.julia_project)
1434
 
1435
  if cluster_manager is not None:
1436
  Main.eval(f"import ClusterManagers: addprocs_{cluster_manager}")