Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
810bea9
1
Parent(s):
9854909
test: more typing info
Browse files- pysr/julia_helpers.py +5 -4
- pysr/sr.py +7 -7
pysr/julia_helpers.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
"""Functions for initializing the Julia environment and installing deps."""
|
2 |
|
3 |
-
from typing import Any, Callable, cast
|
4 |
|
5 |
import numpy as np
|
6 |
from juliacall import convert as jl_convert # type: ignore
|
|
|
7 |
|
8 |
from .deprecated import init_julia, install
|
9 |
from .julia_import import jl
|
@@ -26,7 +27,7 @@ def _escape_filename(filename):
|
|
26 |
return str_repr
|
27 |
|
28 |
|
29 |
-
def _load_cluster_manager(cluster_manager):
|
30 |
jl.seval(f"using ClusterManagers: addprocs_{cluster_manager}")
|
31 |
return jl.seval(f"addprocs_{cluster_manager}")
|
32 |
|
@@ -37,13 +38,13 @@ def jl_array(x):
|
|
37 |
return jl_convert(jl.Array, x)
|
38 |
|
39 |
|
40 |
-
def jl_serialize(obj):
|
41 |
buf = jl.IOBuffer()
|
42 |
Serialization.serialize(buf, obj)
|
43 |
return np.array(jl.take_b(buf))
|
44 |
|
45 |
|
46 |
-
def jl_deserialize(s):
|
47 |
if s is None:
|
48 |
return s
|
49 |
buf = jl.IOBuffer()
|
|
|
1 |
"""Functions for initializing the Julia environment and installing deps."""
|
2 |
|
3 |
+
from typing import Any, Callable, Union, cast
|
4 |
|
5 |
import numpy as np
|
6 |
from juliacall import convert as jl_convert # type: ignore
|
7 |
+
from numpy.typing import NDArray
|
8 |
|
9 |
from .deprecated import init_julia, install
|
10 |
from .julia_import import jl
|
|
|
27 |
return str_repr
|
28 |
|
29 |
|
30 |
+
def _load_cluster_manager(cluster_manager: str):
|
31 |
jl.seval(f"using ClusterManagers: addprocs_{cluster_manager}")
|
32 |
return jl.seval(f"addprocs_{cluster_manager}")
|
33 |
|
|
|
38 |
return jl_convert(jl.Array, x)
|
39 |
|
40 |
|
41 |
+
def jl_serialize(obj: Any) -> NDArray[np.uint8]:
|
42 |
buf = jl.IOBuffer()
|
43 |
Serialization.serialize(buf, obj)
|
44 |
return np.array(jl.take_b(buf))
|
45 |
|
46 |
|
47 |
+
def jl_deserialize(s: Union[NDArray[np.uint8], None]):
|
48 |
if s is None:
|
49 |
return s
|
50 |
buf = jl.IOBuffer()
|
pysr/sr.py
CHANGED
@@ -667,19 +667,19 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
667 |
```
|
668 |
"""
|
669 |
|
670 |
-
equations_:
|
671 |
n_features_in_: int
|
672 |
feature_names_in_: ArrayLike[str]
|
673 |
display_feature_names_in_: ArrayLike[str]
|
674 |
-
X_units_:
|
675 |
-
y_units_:
|
676 |
nout_: int
|
677 |
-
selection_mask_:
|
678 |
tempdir_: Path
|
679 |
equation_file_: Union[str, Path]
|
680 |
-
julia_state_stream_:
|
681 |
-
julia_options_stream_:
|
682 |
-
equation_file_contents_:
|
683 |
show_pickle_warnings_: bool
|
684 |
|
685 |
def __init__(
|
|
|
667 |
```
|
668 |
"""
|
669 |
|
670 |
+
equations_: Union[pd.DataFrame, List[pd.DataFrame], None]
|
671 |
n_features_in_: int
|
672 |
feature_names_in_: ArrayLike[str]
|
673 |
display_feature_names_in_: ArrayLike[str]
|
674 |
+
X_units_: Union[ArrayLike[str], None]
|
675 |
+
y_units_: Union[str, ArrayLike[str], None]
|
676 |
nout_: int
|
677 |
+
selection_mask_: Union[NDArray[np.bool_], None]
|
678 |
tempdir_: Path
|
679 |
equation_file_: Union[str, Path]
|
680 |
+
julia_state_stream_: Union[NDArray[np.uint8], None]
|
681 |
+
julia_options_stream_: Union[NDArray[np.uint8], None]
|
682 |
+
equation_file_contents_: Union[List[pd.DataFrame], None]
|
683 |
show_pickle_warnings_: bool
|
684 |
|
685 |
def __init__(
|