Spaces:
Running
Running
AutonLabTruth
commited on
Commit
•
181a454
1
Parent(s):
762987c
Refactored till handle_constraints
Browse files- pysr/sr.py +73 -55
pysr/sr.py
CHANGED
@@ -207,75 +207,29 @@ def pysr(X=None, y=None, weights=None,
|
|
207 |
|
208 |
check_assertions(X, binary_operators, unary_operators, use_custom_variable_names, variable_names, weights, y)
|
209 |
|
|
|
|
|
210 |
if maxdepth is None:
|
211 |
maxdepth = maxsize
|
212 |
if equation_file is None:
|
213 |
date_time = datetime.now().strftime("%Y-%m-%d_%H%M%S.%f")[:-3]
|
214 |
equation_file = 'hall_of_fame_' + date_time + '.csv'
|
215 |
-
|
216 |
-
if select_k_features is not None:
|
217 |
-
selection = run_feature_selection(X, y, select_k_features)
|
218 |
-
print(f"Using features {selection}")
|
219 |
-
X = X[:, selection]
|
220 |
-
|
221 |
-
if use_custom_variable_names:
|
222 |
-
variable_names = [variable_names[selection[i]] for i in range(len(selection))]
|
223 |
-
|
224 |
if populations is None:
|
225 |
populations = procs
|
226 |
-
|
227 |
-
|
228 |
-
if isinstance(unary_operators, str):
|
229 |
-
|
230 |
if X is None:
|
231 |
-
|
232 |
-
eval_str = "np.sign(X[:, 2])*np.abs(X[:, 2])**2.5 + 5*np.cos(X[:, 3]) - 5"
|
233 |
-
elif test == 'simple2':
|
234 |
-
eval_str = "np.sign(X[:, 2])*np.abs(X[:, 2])**3.5 + 1/(np.abs(X[:, 0])+1)"
|
235 |
-
elif test == 'simple3':
|
236 |
-
eval_str = "np.exp(X[:, 0]/2) + 12.0 + np.log(np.abs(X[:, 0])*10 + 1)"
|
237 |
-
elif test == 'simple4':
|
238 |
-
eval_str = "1.0 + 3*X[:, 0]**2 - 0.5*X[:, 0]**3 + 0.1*X[:, 0]**4"
|
239 |
-
elif test == 'simple5':
|
240 |
-
eval_str = "(np.exp(X[:, 3]) + 3)/(np.abs(X[:, 1]) + np.cos(X[:, 0]) + 1.1)"
|
241 |
-
|
242 |
-
X = np.random.randn(100, 5)*3
|
243 |
-
y = eval(eval_str)
|
244 |
-
print("Running on", eval_str)
|
245 |
|
246 |
def_hyperparams = ""
|
247 |
|
248 |
# Add pre-defined functions to Julia
|
249 |
-
|
250 |
-
for i in range(len(op_list)):
|
251 |
-
op = op_list[i]
|
252 |
-
is_user_defined_operator = '(' in op
|
253 |
-
|
254 |
-
if is_user_defined_operator:
|
255 |
-
def_hyperparams += op + "\n"
|
256 |
-
# Cut off from the first non-alphanumeric char:
|
257 |
-
first_non_char = [
|
258 |
-
j for j in range(len(op))
|
259 |
-
if not (op[j].isalpha() or op[j].isdigit())][0]
|
260 |
-
function_name = op[:first_non_char]
|
261 |
-
op_list[i] = function_name
|
262 |
|
263 |
#arbitrary complexity by default
|
264 |
-
|
265 |
-
if op not in constraints:
|
266 |
-
constraints[op] = -1
|
267 |
-
for op in binary_operators:
|
268 |
-
if op not in constraints:
|
269 |
-
constraints[op] = (-1, -1)
|
270 |
-
if op in ['plus', 'sub']:
|
271 |
-
if constraints[op][0] != constraints[op][1]:
|
272 |
-
raise NotImplementedError("You need equal constraints on both sides for - and *, due to simplification strategies.")
|
273 |
-
elif op == 'mult':
|
274 |
-
# Make sure the complex expression is in the left side.
|
275 |
-
if constraints[op][0] == -1:
|
276 |
-
continue
|
277 |
-
elif constraints[op][1] == -1 or constraints[op][0] < constraints[op][1]:
|
278 |
-
constraints[op][0], constraints[op][1] = constraints[op][1], constraints[op][0]
|
279 |
|
280 |
constraints_str = "const una_constraints = ["
|
281 |
first = True
|
@@ -445,6 +399,70 @@ const varMap = {'["' + '", "'.join(variable_names) + '"]'}"""
|
|
445 |
return get_hof()
|
446 |
|
447 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
448 |
def set_paths(tempdir):
|
449 |
# System-independent paths
|
450 |
pkg_directory = Path(__file__).parents[1] / 'julia'
|
|
|
207 |
|
208 |
check_assertions(X, binary_operators, unary_operators, use_custom_variable_names, variable_names, weights, y)
|
209 |
|
210 |
+
X, variable_names = handle_feature_selection(X, select_k_features, use_custom_variable_names, variable_names, y)
|
211 |
+
|
212 |
if maxdepth is None:
|
213 |
maxdepth = maxsize
|
214 |
if equation_file is None:
|
215 |
date_time = datetime.now().strftime("%Y-%m-%d_%H%M%S.%f")[:-3]
|
216 |
equation_file = 'hall_of_fame_' + date_time + '.csv'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
if populations is None:
|
218 |
populations = procs
|
219 |
+
if isinstance(binary_operators, str):
|
220 |
+
binary_operators = [binary_operators]
|
221 |
+
if isinstance(unary_operators, str):
|
222 |
+
unary_operators = [unary_operators]
|
223 |
if X is None:
|
224 |
+
X, y = using_test_input(X, test, y)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
|
226 |
def_hyperparams = ""
|
227 |
|
228 |
# Add pre-defined functions to Julia
|
229 |
+
def_hyperparams = predefined_function_addition(binary_operators, def_hyperparams, unary_operators)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
|
231 |
#arbitrary complexity by default
|
232 |
+
handle_constraints(binary_operators, constraints, unary_operators)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
|
234 |
constraints_str = "const una_constraints = ["
|
235 |
first = True
|
|
|
399 |
return get_hof()
|
400 |
|
401 |
|
402 |
+
def handle_constraints(binary_operators, constraints, unary_operators):
|
403 |
+
for op in unary_operators:
|
404 |
+
if op not in constraints:
|
405 |
+
constraints[op] = -1
|
406 |
+
for op in binary_operators:
|
407 |
+
if op not in constraints:
|
408 |
+
constraints[op] = (-1, -1)
|
409 |
+
if op in ['plus', 'sub']:
|
410 |
+
if constraints[op][0] != constraints[op][1]:
|
411 |
+
raise NotImplementedError(
|
412 |
+
"You need equal constraints on both sides for - and *, due to simplification strategies.")
|
413 |
+
elif op == 'mult':
|
414 |
+
# Make sure the complex expression is in the left side.
|
415 |
+
if constraints[op][0] == -1:
|
416 |
+
continue
|
417 |
+
elif constraints[op][1] == -1 or constraints[op][0] < constraints[op][1]:
|
418 |
+
constraints[op][0], constraints[op][1] = constraints[op][1], constraints[op][0]
|
419 |
+
|
420 |
+
|
421 |
+
def predefined_function_addition(binary_operators, def_hyperparams, unary_operators):
|
422 |
+
for op_list in [binary_operators, unary_operators]:
|
423 |
+
for i in range(len(op_list)):
|
424 |
+
op = op_list[i]
|
425 |
+
is_user_defined_operator = '(' in op
|
426 |
+
|
427 |
+
if is_user_defined_operator:
|
428 |
+
def_hyperparams += op + "\n"
|
429 |
+
# Cut off from the first non-alphanumeric char:
|
430 |
+
first_non_char = [
|
431 |
+
j for j in range(len(op))
|
432 |
+
if not (op[j].isalpha() or op[j].isdigit())][0]
|
433 |
+
function_name = op[:first_non_char]
|
434 |
+
op_list[i] = function_name
|
435 |
+
return def_hyperparams
|
436 |
+
|
437 |
+
|
438 |
+
def using_test_input(X, test, y):
|
439 |
+
if test == 'simple1':
|
440 |
+
eval_str = "np.sign(X[:, 2])*np.abs(X[:, 2])**2.5 + 5*np.cos(X[:, 3]) - 5"
|
441 |
+
elif test == 'simple2':
|
442 |
+
eval_str = "np.sign(X[:, 2])*np.abs(X[:, 2])**3.5 + 1/(np.abs(X[:, 0])+1)"
|
443 |
+
elif test == 'simple3':
|
444 |
+
eval_str = "np.exp(X[:, 0]/2) + 12.0 + np.log(np.abs(X[:, 0])*10 + 1)"
|
445 |
+
elif test == 'simple4':
|
446 |
+
eval_str = "1.0 + 3*X[:, 0]**2 - 0.5*X[:, 0]**3 + 0.1*X[:, 0]**4"
|
447 |
+
elif test == 'simple5':
|
448 |
+
eval_str = "(np.exp(X[:, 3]) + 3)/(np.abs(X[:, 1]) + np.cos(X[:, 0]) + 1.1)"
|
449 |
+
X = np.random.randn(100, 5) * 3
|
450 |
+
y = eval(eval_str)
|
451 |
+
print("Running on", eval_str)
|
452 |
+
return X, y
|
453 |
+
|
454 |
+
|
455 |
+
def handle_feature_selection(X, select_k_features, use_custom_variable_names, variable_names, y):
|
456 |
+
if select_k_features is not None:
|
457 |
+
selection = run_feature_selection(X, y, select_k_features)
|
458 |
+
print(f"Using features {selection}")
|
459 |
+
X = X[:, selection]
|
460 |
+
|
461 |
+
if use_custom_variable_names:
|
462 |
+
variable_names = [variable_names[selection[i]] for i in range(len(selection))]
|
463 |
+
return X, variable_names
|
464 |
+
|
465 |
+
|
466 |
def set_paths(tempdir):
|
467 |
# System-independent paths
|
468 |
pkg_directory = Path(__file__).parents[1] / 'julia'
|