MilesCranmer committed on
Commit
1efb6f4
1 Parent(s): b158e1f

Initial working version with PyJulia

Browse files
Files changed (1) hide show
  1. pysr/sr.py +62 -8
pysr/sr.py CHANGED
@@ -12,6 +12,7 @@ from datetime import datetime
12
  import warnings
13
  from multiprocessing import cpu_count
14
 
 
15
  global_state = dict(
16
  equation_file="hall_of_fame.csv",
17
  n_features=None,
@@ -132,6 +133,7 @@ def pysr(
132
  Xresampled=None,
133
  precision=32,
134
  multithreading=None,
 
135
  ):
136
  """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
137
  Note: most default parameters have been tuned over several example
@@ -254,6 +256,8 @@ def pysr(
254
  :type precision: int
255
  :param multithreading: Use multithreading instead of distributed backend. Default is yes. Using procs=0 will turn off both.
256
  :type multithreading: bool
 
 
257
  :returns: Results dataframe, giving complexity, MSE, and equations (as strings), as well as functional forms. If list, each element corresponds to a dataframe of equations for each output.
258
  :type: pd.DataFrame/list
259
  """
@@ -272,7 +276,18 @@ def pysr(
272
  # or procs is set to 0 (serial mode).
273
  multithreading = procs != 0
274
 
275
- buffer_available = "buffer" in sys.stdout.__dir__()
 
 
 
 
 
 
 
 
 
 
 
276
 
277
  if progress is not None:
278
  if progress and not buffer_available:
@@ -280,6 +295,11 @@ def pysr(
280
  "Note: it looks like you are running in Jupyter. The progress bar will be turned off."
281
  )
282
  progress = False
 
 
 
 
 
283
  else:
284
  progress = buffer_available
285
 
@@ -321,7 +341,8 @@ def pysr(
321
  weights,
322
  y,
323
  )
324
- _check_for_julia_installation()
 
325
 
326
  if len(X) > 10000 and not batching:
327
  warnings.warn(
@@ -437,6 +458,7 @@ def pysr(
437
  denoise=denoise,
438
  precision=precision,
439
  multithreading=multithreading,
 
440
  )
441
 
442
  kwargs = {**_set_paths(tempdir), **kwargs}
@@ -457,7 +479,7 @@ def pysr(
457
 
458
  kwargs["need_install"] = False
459
 
460
- if not (manifest_filepath).is_file():
461
  kwargs["need_install"] = (not user_input) or _yesno(
462
  "I will install Julia packages using PySR's Project.toml file. OK?"
463
  )
@@ -471,10 +493,35 @@ def pysr(
471
 
472
  kwargs["constraints_str"] = _make_constraints_str(**kwargs)
473
  kwargs["def_hyperparams"] = _make_hyperparams_julia_str(**kwargs)
474
- kwargs["def_datasets"] = _make_datasets_julia_str(**kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
 
476
  _create_julia_files(**kwargs)
477
- _final_pysr_process(**kwargs)
 
 
 
 
 
 
 
 
478
  _set_globals(**kwargs)
479
 
480
  equations = get_hof(**kwargs)
@@ -558,12 +605,16 @@ def _create_julia_files(
558
  need_install,
559
  update,
560
  multithreading,
 
561
  **kwargs,
562
  ):
563
  with open(hyperparam_filename, "w") as f:
564
  print(def_hyperparams, file=f)
565
- with open(dataset_filename, "w") as f:
566
- print(def_datasets, file=f)
 
 
 
567
  with open(runfile_filename, "w") as f:
568
  if julia_project is None:
569
  julia_project = pkg_directory
@@ -579,7 +630,10 @@ def _create_julia_files(
579
  print(f"Pkg.update()", file=f)
580
  print(f"using SymbolicRegression", file=f)
581
  print(f'include("{_escape_filename(hyperparam_filename)}")', file=f)
582
- print(f'include("{_escape_filename(dataset_filename)}")', file=f)
 
 
 
583
  if len(variable_names) == 0:
584
  varMap = "[" + ",".join([f'"x{i}"' for i in range(X.shape[1])]) + "]"
585
  else:
 
12
  import warnings
13
  from multiprocessing import cpu_count
14
 
15
+ Main = None
16
  global_state = dict(
17
  equation_file="hall_of_fame.csv",
18
  n_features=None,
 
133
  Xresampled=None,
134
  precision=32,
135
  multithreading=None,
136
+ pyjulia=False,
137
  ):
138
  """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
139
  Note: most default parameters have been tuned over several example
 
256
  :type precision: int
257
  :param multithreading: Use multithreading instead of distributed backend. Default is yes. Using procs=0 will turn off both.
258
  :type multithreading: bool
259
+ :param pyjulia: Whether to use PyJulia instead of julia binary. PyJulia should reduce startup time for repeat calls.
260
+ :type pyjulia: bool
261
  :returns: Results dataframe, giving complexity, MSE, and equations (as strings), as well as functional forms. If list, each element corresponds to a dataframe of equations for each output.
262
  :type: pd.DataFrame/list
263
  """
 
276
  # or procs is set to 0 (serial mode).
277
  multithreading = procs != 0
278
 
279
+ # Start up Julia:
280
+ global Main
281
+ if pyjulia and Main is None:
282
+ if not multithreading:
283
+ raise AssertionError(
284
+ "PyJulia does not support multiprocessing. Turn multithreading=True."
285
+ )
286
+
287
+ os.environ["JULIA_NUM_THREADS"] = str(procs)
288
+ from julia import Main
289
+
290
+ buffer_available = "buffer" in sys.stdout.__dir__() and not pyjulia
291
 
292
  if progress is not None:
293
  if progress and not buffer_available:
 
295
  "Note: it looks like you are running in Jupyter. The progress bar will be turned off."
296
  )
297
  progress = False
298
+ if progress and pyjulia:
299
+ warnings.warn(
300
+ "Note: it looks like you are using PyJulia. The progress bar will be turned off."
301
+ )
302
+ progress = False
303
  else:
304
  progress = buffer_available
305
 
 
341
  weights,
342
  y,
343
  )
344
+ if not pyjulia:
345
+ _check_for_julia_installation()
346
 
347
  if len(X) > 10000 and not batching:
348
  warnings.warn(
 
458
  denoise=denoise,
459
  precision=precision,
460
  multithreading=multithreading,
461
+ pyjulia=pyjulia,
462
  )
463
 
464
  kwargs = {**_set_paths(tempdir), **kwargs}
 
479
 
480
  kwargs["need_install"] = False
481
 
482
+ if not (manifest_filepath).is_file() and not pyjulia:
483
  kwargs["need_install"] = (not user_input) or _yesno(
484
  "I will install Julia packages using PySR's Project.toml file. OK?"
485
  )
 
493
 
494
  kwargs["constraints_str"] = _make_constraints_str(**kwargs)
495
  kwargs["def_hyperparams"] = _make_hyperparams_julia_str(**kwargs)
496
+
497
+ if pyjulia:
498
+ np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[precision]
499
+
500
+ Main.X = np.array(X, dtype=np_dtype).T
501
+ if len(y.shape) == 1:
502
+ Main.y = np.array(y, dtype=np_dtype)
503
+ else:
504
+ Main.y = np.array(y, dtype=np_dtype).T
505
+ if weights is not None:
506
+ if len(weights.shape) == 1:
507
+ Main.weights = np.array(weights, dtype=np_dtype)
508
+ else:
509
+ Main.weights = np.array(weights, dtype=np_dtype).T
510
+
511
+ kwargs["def_datasets"] = ""
512
+ else:
513
+ kwargs["def_datasets"] = _make_datasets_julia_str(**kwargs)
514
 
515
  _create_julia_files(**kwargs)
516
+ if pyjulia:
517
+ # Read entire file as a single string:
518
+ with open(kwargs["runfile_filename"], "r") as f:
519
+ runfile_string = f.read()
520
+ print("Running main runfile in PyJulia!")
521
+ Main.eval(runfile_string)
522
+ else:
523
+ _final_pysr_process(**kwargs)
524
+
525
  _set_globals(**kwargs)
526
 
527
  equations = get_hof(**kwargs)
 
605
  need_install,
606
  update,
607
  multithreading,
608
+ pyjulia,
609
  **kwargs,
610
  ):
611
  with open(hyperparam_filename, "w") as f:
612
  print(def_hyperparams, file=f)
613
+
614
+ if not pyjulia:
615
+ with open(dataset_filename, "w") as f:
616
+ print(def_datasets, file=f)
617
+
618
  with open(runfile_filename, "w") as f:
619
  if julia_project is None:
620
  julia_project = pkg_directory
 
630
  print(f"Pkg.update()", file=f)
631
  print(f"using SymbolicRegression", file=f)
632
  print(f'include("{_escape_filename(hyperparam_filename)}")', file=f)
633
+
634
+ if not pyjulia:
635
+ print(f'include("{_escape_filename(dataset_filename)}")', file=f)
636
+
637
  if len(variable_names) == 0:
638
  varMap = "[" + ",".join([f'"x{i}"' for i in range(X.shape[1])]) + "]"
639
  else: