MilesCranmer commited on
Commit
3538029
1 Parent(s): 8af3119

Update docs for multi-output

Browse files
Files changed (1) hide show
  1. pysr/sr.py +10 -5
pysr/sr.py CHANGED
@@ -125,9 +125,12 @@ def pysr(X=None, y=None, weights=None,
125
  :param X: np.ndarray or pandas.DataFrame, 2D array. Rows are examples,
126
  columns are features. If pandas DataFrame, the columns are used
127
  for variable names (so make sure they don't contain spaces).
128
- :param y: np.ndarray, 1D array. Rows are examples.
129
- :param weights: np.ndarray, 1D array. Each row is how to weight the
130
- mean-square-error loss on weights.
 
 
 
131
  :param binary_operators: list, List of strings giving the binary operators
132
  in Julia's Base. Default is ["+", "-", "*", "/",].
133
  :param unary_operators: list, Same but for operators taking a single scalar.
@@ -227,8 +230,10 @@ def pysr(X=None, y=None, weights=None,
227
  delete_tempfiles argument.
228
  :param output_jax_format: Whether to create a 'jax_format' column in the output,
229
  containing jax-callable functions and the default parameters in a jax array.
230
- :returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
231
- (as strings).
 
 
232
 
233
  """
234
  if binary_operators is None:
 
125
  :param X: np.ndarray or pandas.DataFrame, 2D array. Rows are examples,
126
  columns are features. If pandas DataFrame, the columns are used
127
  for variable names (so make sure they don't contain spaces).
128
+ :param y: np.ndarray, 1D array (rows are examples) or 2D array (rows
129
+ are examples, columns are outputs). Putting in a 2D array will
130
+ trigger a search for equations for each feature of y.
131
+ :param weights: np.ndarray, same shape as y. Each element is how to
132
+ weight the mean-square-error loss for that particular element
133
+ of y.
134
  :param binary_operators: list, List of strings giving the binary operators
135
  in Julia's Base. Default is ["+", "-", "*", "/",].
136
  :param unary_operators: list, Same but for operators taking a single scalar.
 
230
  delete_tempfiles argument.
231
  :param output_jax_format: Whether to create a 'jax_format' column in the output,
232
  containing jax-callable functions and the default parameters in a jax array.
233
+ :returns: pd.DataFrame or list, Results dataframe,
234
+ giving complexity, MSE, and equations (as strings), as well as functional
235
+ forms. If list, each element corresponds to a dataframe of equations
236
+ for each output.
237
 
238
  """
239
  if binary_operators is None: