# Status banner "Spaces: Runtime error" (page residue captured with this file; not part of the program).
"""
This file is for running; do not modify it.
To run: DiffEqnSolver.py
To change settings, edit: Settings.py
"""
import os
import time

import DataUtils
import Settings as settings
from SymbolicFunctionLearner import SFL

# Select differential-equation mode before constructing the model.
# Presumably SFL() reads settings.mode during construction (TODO confirm),
# so this assignment must stay ahead of the constructor call.
settings.mode = "de"
current_model = SFL()

# Ensure an 'images' directory exists in the working directory. Nothing in
# this file writes to it; presumably the model saves plots there (TODO confirm).
# exist_ok=True replaces the racy exists()-then-makedirs() pair with one call.
os.makedirs('images', exist_ok=True)
# Echo the experiment configuration so each run's log records how it was set up.
print('\nBeginning experiment: {}'.format(current_model.name))
print("{} tree layers.".format(settings.n_tree_layers))
print("{} features of {} component(s) each.".format(
    settings.num_features,
    settings.num_dims_per_feature,
))
print("{} components in output.".format(settings.n_dims_in_output))
operator_set = current_model.function_set
print("{} operators: {}.".format(len(operator_set), operator_set))
# Per-run metric accumulators (four distinct empty lists).
# NOTE(review): none of these are read anywhere in this file; confirm whether
# they are vestigial.
train_errors, valid_errors, test_errors, true_eqns = [], [], [], []

# Sample input points. Train and validation sets each draw settings.train_N
# points with the same options. The test set draws settings.test_N points over
# settings.test_scope and does not pass avoid_zero.
n_vars = current_model.n_input_variables
train_X = DataUtils.generate_data(
    settings.train_N,
    n_vars=n_vars,
    avoid_zero=settings.avoid_zero,
)
# NOTE(review): valid_X uses train_N (not a separate validation size) and is
# never used below. The call is kept because it may advance a shared random
# state that the later sampling and training depend on.
valid_X = DataUtils.generate_data(
    settings.train_N,
    n_vars=n_vars,
    avoid_zero=settings.avoid_zero,
)
low_x, high_x = settings.test_scope[0], settings.test_scope[1]
test_X = DataUtils.generate_data(
    settings.test_N,
    n_vars=n_vars,
    min_x=low_x,
    max_x=high_x,
)
print("\n========================")
print("Starting Solver.")
print("==========================\n")
# Train the model from scratch several times, keeping the best one (as the
# returned best_model / best_iter / best_err names indicate).
# NOTE(review): if repeat_train launches worker processes (the setting name
# num_train_repeat_processes suggests so; not visible here), platforms using
# the 'spawn' start method re-import this module in each worker. That re-runs
# the whole script because it has no `if __name__ == "__main__":` guard —
# confirm.
# time.perf_counter() is monotonic, so the measured duration is not skewed by
# system-clock adjustments during a long run, unlike time.time().
start_time = time.perf_counter()
best_model, best_iter, best_err = current_model.repeat_train(
    train_X,
    num_repeats=settings.num_train_repeat_processes,
    test_x=test_X,
)
running_time = time.perf_counter() - start_time
# Test error below which the run is reported as a success. It is named once so
# the comparison and the message that quotes it cannot drift apart.
SUCCESS_ERROR_THRESHOLD = 0.02

# Summarize the best run: model, timing, attempt index and test error.
print("best_model: {}".format(best_model))
print("----------------------")
print("Finished DE. Took {:.2f} minutes.\n".format(running_time / 60))
print("Final solution found at attempt {}:".format(best_iter))
print("y = {}".format(best_model))
print("Test error: {}".format(best_err))
if best_err < SUCCESS_ERROR_THRESHOLD:
    # The message prints "0.02" exactly as before; str(0.02) == "0.02".
    print("Attained error less than {} - great!".format(SUCCESS_ERROR_THRESHOLD))
# Trailing blank line, printed unconditionally.
print()