Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
8d0381b
1
Parent(s):
44ff874
Remove unneeded copying
Browse files- pysr/sr.py +9 -40
pysr/sr.py
CHANGED
@@ -7,7 +7,6 @@ import shutil
|
|
7 |
import sys
|
8 |
import tempfile
|
9 |
import warnings
|
10 |
-
from collections import namedtuple
|
11 |
from datetime import datetime
|
12 |
from io import StringIO
|
13 |
from multiprocessing import cpu_count
|
@@ -1723,31 +1722,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1723 |
else:
|
1724 |
jl_y_variable_names = None
|
1725 |
|
1726 |
-
|
1727 |
-
|
1728 |
-
|
1729 |
-
|
1730 |
-
jl._equation_search_kwargs = namedtuple(
|
1731 |
-
"equation_search_kwargs",
|
1732 |
-
(
|
1733 |
-
"weights",
|
1734 |
-
"niterations",
|
1735 |
-
"variable_names",
|
1736 |
-
"display_variable_names",
|
1737 |
-
"y_variable_names",
|
1738 |
-
"X_units",
|
1739 |
-
"y_units",
|
1740 |
-
"options",
|
1741 |
-
"numprocs",
|
1742 |
-
"parallelism",
|
1743 |
-
"saved_state",
|
1744 |
-
"return_state",
|
1745 |
-
"addprocs_function",
|
1746 |
-
"heap_size_hint_in_bytes",
|
1747 |
-
"progress",
|
1748 |
-
"verbosity",
|
1749 |
-
),
|
1750 |
-
)(
|
1751 |
weights=jl_weights,
|
1752 |
niterations=int(self.niterations),
|
1753 |
variable_names=jl_array(self.feature_names_in_.tolist()),
|
@@ -1765,21 +1743,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1765 |
progress=progress and self.verbosity > 0 and len(y.shape) == 1,
|
1766 |
verbosity=int(self.verbosity),
|
1767 |
)
|
1768 |
-
jl.PythonCall.GC.disable()
|
1769 |
-
output_stream = jl.seval(
|
1770 |
-
"""
|
1771 |
-
let args = deepcopy(_equation_search_args), kwargs=deepcopy(_equation_search_kwargs)
|
1772 |
-
out = SymbolicRegression.equation_search(args...; kwargs...)
|
1773 |
-
buf = IOBuffer()
|
1774 |
-
Serialization.serialize(buf, out)
|
1775 |
-
take!(buf)
|
1776 |
-
end
|
1777 |
-
"""
|
1778 |
-
)
|
1779 |
jl.PythonCall.GC.enable()
|
1780 |
-
|
1781 |
-
|
1782 |
-
|
|
|
|
|
1783 |
|
1784 |
# Set attributes
|
1785 |
self.equations_ = self.get_hof()
|
|
|
7 |
import sys
|
8 |
import tempfile
|
9 |
import warnings
|
|
|
10 |
from datetime import datetime
|
11 |
from io import StringIO
|
12 |
from multiprocessing import cpu_count
|
|
|
1722 |
else:
|
1723 |
jl_y_variable_names = None
|
1724 |
|
1725 |
+
jl.PythonCall.GC.disable()
|
1726 |
+
out = SymbolicRegression.equation_search(
|
1727 |
+
jl_X,
|
1728 |
+
jl_y,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1729 |
weights=jl_weights,
|
1730 |
niterations=int(self.niterations),
|
1731 |
variable_names=jl_array(self.feature_names_in_.tolist()),
|
|
|
1743 |
progress=progress and self.verbosity > 0 and len(y.shape) == 1,
|
1744 |
verbosity=int(self.verbosity),
|
1745 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1746 |
jl.PythonCall.GC.enable()
|
1747 |
+
|
1748 |
+
# Serialize output (for pickling)
|
1749 |
+
buf = jl.IOBuffer()
|
1750 |
+
jl.Serialization.serialize(buf, out)
|
1751 |
+
self.raw_julia_state_stream_ = np.array(jl.take_b(buf))
|
1752 |
|
1753 |
# Set attributes
|
1754 |
self.equations_ = self.get_hof()
|