MilesCranmer commited on
Commit
8d0381b
1 Parent(s): 44ff874

Remove unneeded copying

Browse files
Files changed (1) hide show
  1. pysr/sr.py +9 -40
pysr/sr.py CHANGED
@@ -7,7 +7,6 @@ import shutil
7
  import sys
8
  import tempfile
9
  import warnings
10
- from collections import namedtuple
11
  from datetime import datetime
12
  from io import StringIO
13
  from multiprocessing import cpu_count
@@ -1723,31 +1722,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1723
  else:
1724
  jl_y_variable_names = None
1725
 
1726
- # Because we call some multi-threading code, we first create the arguments.
1727
- # We do a deep copy of them **within Julia**, so that
1728
- # Python's garbage collection is unaware of them.
1729
- jl._equation_search_args = (jl_X, jl_y)
1730
- jl._equation_search_kwargs = namedtuple(
1731
- "equation_search_kwargs",
1732
- (
1733
- "weights",
1734
- "niterations",
1735
- "variable_names",
1736
- "display_variable_names",
1737
- "y_variable_names",
1738
- "X_units",
1739
- "y_units",
1740
- "options",
1741
- "numprocs",
1742
- "parallelism",
1743
- "saved_state",
1744
- "return_state",
1745
- "addprocs_function",
1746
- "heap_size_hint_in_bytes",
1747
- "progress",
1748
- "verbosity",
1749
- ),
1750
- )(
1751
  weights=jl_weights,
1752
  niterations=int(self.niterations),
1753
  variable_names=jl_array(self.feature_names_in_.tolist()),
@@ -1765,21 +1743,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1765
  progress=progress and self.verbosity > 0 and len(y.shape) == 1,
1766
  verbosity=int(self.verbosity),
1767
  )
1768
- jl.PythonCall.GC.disable()
1769
- output_stream = jl.seval(
1770
- """
1771
- let args = deepcopy(_equation_search_args), kwargs=deepcopy(_equation_search_kwargs)
1772
- out = SymbolicRegression.equation_search(args...; kwargs...)
1773
- buf = IOBuffer()
1774
- Serialization.serialize(buf, out)
1775
- take!(buf)
1776
- end
1777
- """
1778
- )
1779
  jl.PythonCall.GC.enable()
1780
- jl._equation_search_args = None
1781
- jl._equation_search_kwargs = None
1782
- self.raw_julia_state_stream_ = np.array(output_stream)
 
 
1783
 
1784
  # Set attributes
1785
  self.equations_ = self.get_hof()
 
7
  import sys
8
  import tempfile
9
  import warnings
 
10
  from datetime import datetime
11
  from io import StringIO
12
  from multiprocessing import cpu_count
 
1722
  else:
1723
  jl_y_variable_names = None
1724
 
1725
+ jl.PythonCall.GC.disable()
1726
+ out = SymbolicRegression.equation_search(
1727
+ jl_X,
1728
+ jl_y,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1729
  weights=jl_weights,
1730
  niterations=int(self.niterations),
1731
  variable_names=jl_array(self.feature_names_in_.tolist()),
 
1743
  progress=progress and self.verbosity > 0 and len(y.shape) == 1,
1744
  verbosity=int(self.verbosity),
1745
  )
 
 
 
 
 
 
 
 
 
 
 
1746
  jl.PythonCall.GC.enable()
1747
+
1748
+ # Serialize output (for pickling)
1749
+ buf = jl.IOBuffer()
1750
+ jl.Serialization.serialize(buf, out)
1751
+ self.raw_julia_state_stream_ = np.array(jl.take_b(buf))
1752
 
1753
  # Set attributes
1754
  self.equations_ = self.get_hof()