AutonLabTruth commited on
Commit
0dfd8e3
·
1 Parent(s): a9a1691

Refactored out paths and others

Browse files
Files changed (1) hide show
  1. pysr/sr.py +31 -22
pysr/sr.py CHANGED
@@ -207,16 +207,7 @@ def pysr(X=None, y=None, weights=None,
207
  if len(X.shape) == 1:
208
  X = X[:, None]
209
 
210
- # Check for potential errors before they happen
211
- assert len(unary_operators) + len(binary_operators) > 0
212
- assert len(X.shape) == 2
213
- assert len(y.shape) == 1
214
- assert X.shape[0] == y.shape[0]
215
- if weights is not None:
216
- assert len(weights.shape) == 1
217
- assert X.shape[0] == weights.shape[0]
218
- if use_custom_variable_names:
219
- assert len(variable_names) == X.shape[1]
220
 
221
  if select_k_features is not None:
222
  selection = run_feature_selection(X, y, select_k_features)
@@ -248,18 +239,8 @@ def pysr(X=None, y=None, weights=None,
248
  y = eval(eval_str)
249
  print("Running on", eval_str)
250
 
251
- # System-independent paths
252
- pkg_directory = Path(__file__).parents[1] / 'julia'
253
- pkg_filename = pkg_directory / "sr.jl"
254
- operator_filename = pkg_directory / "operators.jl"
255
-
256
- tmpdir = Path(tempfile.mkdtemp(dir=tempdir))
257
- hyperparam_filename = tmpdir / f'hyperparams.jl'
258
- dataset_filename = tmpdir / f'dataset.jl'
259
- runfile_filename = tmpdir / f'runfile.jl'
260
- X_filename = tmpdir / "X.csv"
261
- y_filename = tmpdir / "y.csv"
262
- weights_filename = tmpdir / "weights.csv"
263
 
264
  def_hyperparams = ""
265
 
@@ -463,6 +444,34 @@ const varMap = {'["' + '", "'.join(variable_names) + '"]'}"""
463
  return get_hof()
464
 
465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  def raise_depreciation_errors(limitPowComplexity, threads):
467
  if threads is not None:
468
  raise ValueError("The threads kwarg is deprecated. Use procs.")
 
207
  if len(X.shape) == 1:
208
  X = X[:, None]
209
 
210
+ check_assertions(X, binary_operators, unary_operators, use_custom_variable_names, variable_names, weights, y)
 
 
 
 
 
 
 
 
 
211
 
212
  if select_k_features is not None:
213
  selection = run_feature_selection(X, y, select_k_features)
 
239
  y = eval(eval_str)
240
  print("Running on", eval_str)
241
 
242
+ X_filename, dataset_filename, hyperparam_filename, operator_filename, pkg_filename, runfile_filename, tmpdir, weights_filename, y_filename = set_paths(
243
+ tempdir)
 
 
 
 
 
 
 
 
 
 
244
 
245
  def_hyperparams = ""
246
 
 
444
  return get_hof()
445
 
446
 
447
+ def set_paths(tempdir):
448
+ # System-independent paths
449
+ pkg_directory = Path(__file__).parents[1] / 'julia'
450
+ pkg_filename = pkg_directory / "sr.jl"
451
+ operator_filename = pkg_directory / "operators.jl"
452
+ tmpdir = Path(tempfile.mkdtemp(dir=tempdir))
453
+ hyperparam_filename = tmpdir / f'hyperparams.jl'
454
+ dataset_filename = tmpdir / f'dataset.jl'
455
+ runfile_filename = tmpdir / f'runfile.jl'
456
+ X_filename = tmpdir / "X.csv"
457
+ y_filename = tmpdir / "y.csv"
458
+ weights_filename = tmpdir / "weights.csv"
459
+ return X_filename, dataset_filename, hyperparam_filename, operator_filename, pkg_filename, runfile_filename, tmpdir, weights_filename, y_filename
460
+
461
+
462
+ def check_assertions(X, binary_operators, unary_operators, use_custom_variable_names, variable_names, weights, y):
463
+ # Check for potential errors before they happen
464
+ assert len(unary_operators) + len(binary_operators) > 0
465
+ assert len(X.shape) == 2
466
+ assert len(y.shape) == 1
467
+ assert X.shape[0] == y.shape[0]
468
+ if weights is not None:
469
+ assert len(weights.shape) == 1
470
+ assert X.shape[0] == weights.shape[0]
471
+ if use_custom_variable_names:
472
+ assert len(variable_names) == X.shape[1]
473
+
474
+
475
  def raise_depreciation_errors(limitPowComplexity, threads):
476
  if threads is not None:
477
  raise ValueError("The threads kwarg is deprecated. Use procs.")