"""Predict temperature profiles with a saved pointwise MLP.

Reads simulation parameters (RaQ/Ra, FKT, FKV) from a text file, evaluates
the network on a fixed depth grid, and optionally writes each profile to a
text file and/or plots it to a PNG under ``outputs/``.
"""
import os
import pickle
import warnings

import numpy as np
from matplotlib import pyplot as plt

from utils import *

#### Define outputs ####
write_file = True
plot_profile = True
#### Define outputs ####

# Parameter file; each of the lines containing r_list / t_list / v_list must
# look like ``name = v1, v2, ..., vn,`` (note the trailing comma).
f_nn = "my_simulation_parameters.txt"


def _parse_param_list(stripped_line, fname):
    """Parse a ``name = v1, v2, ..., vn,`` line into a list of floats.

    The value list must end with a trailing comma; the empty field after that
    comma is discarded.

    Raises:
        Exception: if the line does not end with a comma.
    """
    if not stripped_line.endswith(","):
        raise Exception("Ensure there is a comma after last parameter value in " + fname)
    return [float(p) for p in stripped_line.split("=")[1].split(",")[:-1]]


def _output_basename(raq_ra, fkt, fkv):
    """Return the output path (without extension) for one parameter set."""
    return "outputs/profile_raq_ra" + str(raq_ra) + "_fkt" + str(fkt) + "_fkv" + str(fkv)


# load the saved mlp
# NOTE: pickle.load can execute arbitrary code; only load network files from a trusted source.
with open('numpy_networks/mlp_[128, 128, 128, 128, 128].pkl', 'rb') as file:
    mlp = pickle.load(file)

# read the simulation parameters and ensure correct formatting
r_list = t_list = v_list = None
with open(f_nn) as fw:
    for raw_line in fw:
        stripped = raw_line.rstrip()
        if "r_list" in stripped:
            r_list = _parse_param_list(stripped, f_nn)
        elif "t_list" in stripped:
            t_list = _parse_param_list(stripped, f_nn)
        elif "v_list" in stripped:
            v_list = _parse_param_list(stripped, f_nn)

# Fail with a clear message instead of a NameError when a parameter is missing.
if r_list is None or t_list is None or v_list is None:
    raise Exception("Ensure r_list, t_list and v_list are all defined in " + f_nn)
# All three lists must have the same length (the original check mis-parsed
# `not A and B` as `(not A) and B` and could miss mismatches).
if not (len(r_list) == len(t_list) == len(v_list)):
    raise Exception("Ensure equal number of values for all parameters in " + f_nn)

# Check parameter ranges and print appropriate warnings
for raq_ra, fkt, fkv in zip(r_list, t_list, v_list):
    if raq_ra < 0 or raq_ra > 9.5:
        warnings.warn('RaQ/Ra is outside the range of the training dataset')
    if fkt < 1e+6 or fkt > 5e+9:
        warnings.warn('FKT is outside the range of the training dataset')
    if fkv < 1 or fkv > 95:
        warnings.warn('FKV is outside the range of the training dataset')

### calculates y points ###
# Depth grid: 1 at the top, num_points-2 interior cell centres (descending), 0 at the bottom.
num_points = 128
y_prof = np.concatenate(([1], np.linspace(0 + 1 / (num_points * 2), 1 - 1 / (num_points * 2), num_points - 2)[::-1], [0]))
### calculates y points ###

### calculates temperature profile ###
x_in = get_input(r_list, t_list, v_list, y_prof)
y_pred_nn_pointwise = get_profile(x_in, mlp, num_sims=len(r_list))
### calculates temperature profile ###

# One output path (without extension) per simulation, shared by text and plot output.
basenames = [_output_basename(r, t, v) for r, t, v in zip(r_list, t_list, v_list)]
if write_file or plot_profile:
    os.makedirs("outputs", exist_ok=True)

### writes out temperature profile ###
if write_file:
    for i, fname in enumerate(basenames):
        # newline="\n" keeps LF endings on all platforms, matching the previous binary-mode output.
        with open(fname + ".txt", "w", encoding="ascii", newline="\n") as f:
            f.writelines(
                f"{y!s} {temp!s}\n" for y, temp in zip(y_prof, y_pred_nn_pointwise[i, :])
            )
### writes out temperature profile ###

### plots temperature profile ###
if plot_profile:
    for i, fname in enumerate(basenames):
        fig = plt.figure()
        plt.plot(y_pred_nn_pointwise[i, :], y_prof, 'k-', linewidth=3.0, label="pointwise neural network")
        plt.ylim([1, 0])
        plt.xlabel("Temperature")
        plt.ylabel("Depth")
        plt.legend()
        plt.grid()
        plt.savefig(fname + ".png")
        # Close each figure so many simulations do not accumulate open figures.
        plt.close(fig)
### plots temperature profile ###