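# steadystate-mantle / calculate_profiles.py
# Agarwal -- commit 07df88c: "update corrections and network"
# Evaluates a pre-trained MLP to produce steady-state mantle temperature
# profiles for the parameter sets listed in my_simulation_parameters.txt.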
from matplotlib import pyplot as plt
from utils import *
import warnings
#### Define outputs ####
# write_file:   save each predicted profile as a text file under outputs/
# plot_profile: intended to toggle the PNG plots saved under outputs/
write_file, plot_profile = True, True
#### Define outputs ####
# Load the pre-trained network. The file name suggests five hidden layers of
# width 128 -- TODO confirm against the training code.
# NOTE: pickle.load can execute arbitrary code stored in the file; only load
# network files from a trusted source.
network_path = 'numpy_networks/mlp_[128, 128, 128, 128, 128].pkl'
with open(network_path, 'rb') as fh:
    mlp = pickle.load(fh)
f_nn = "my_simulation_parameters.txt"
# Each parameter line has the form "<name> = v1, v2, ..., vN," -- the trailing
# comma is mandatory because the last comma-separated field is discarded.
param_names = ("r_list", "t_list", "v_list")
params = {}
with open(f_nn) as fw:
    for line in fw:
        l = line.rstrip()
        # First matching name wins, exactly like an if/elif chain.
        name = next((p for p in param_names if p in l), None)
        if name is None:
            continue
        if not l.endswith(","):
            raise Exception("Ensure there is a comma after last parameter value in " + f_nn)
        params[name] = [float(p) for p in l.split("=")[1].split(",")[:-1]]

# Fail with a clear message instead of a NameError further down.
missing = [p for p in param_names if p not in params]
if missing:
    raise Exception("Missing parameter(s) " + ", ".join(missing) + " in " + f_nn)
r_list, t_list, v_list = (params[p] for p in param_names)

# Chained comparison: the former `not a == b and a == c` parsed as
# `(not a == b) and (a == c)` and let mismatched lengths through.
if not (len(r_list) == len(t_list) == len(v_list)):
    raise Exception("Ensure equal number of values for all parameters in " + f_nn)
# Warn (without stopping) when an input lies outside the training-data bounds:
# RaQ/Ra in [0, 9.5], FKT in [1e6, 5e9], FKV in [1, 95].
# The simulation index and value are part of the message so that every
# offending simulation is reported, not just the first one -- the warnings
# module suppresses repeats of an identical message from the same line.
for i, (r, t, v) in enumerate(zip(r_list, t_list, v_list)):
    if r < 0 or r > 9.5:
        warnings.warn('RaQ/Ra is outside the range of the training dataset'
                      ' (simulation {}: {}, expected 0 to 9.5)'.format(i, r))
    if t < 1e+6 or t > 5e+9:
        warnings.warn('FKT is outside the range of the training dataset'
                      ' (simulation {}: {}, expected 1e+06 to 5e+09)'.format(i, t))
    if v < 1 or v > 95:
        warnings.warn('FKV is outside the range of the training dataset'
                      ' (simulation {}: {}, expected 1 to 95)'.format(i, v))
### calculates y points ###
# Depth coordinates at which the network is evaluated, ordered from 1 down
# to 0: the two end points plus num_points - 2 evenly spaced interior points,
# each end inset by half of 1/num_points.
num_points = 128
half_step = 1 / (num_points * 2)
interior = np.linspace(0 + half_step, 1 - half_step, num_points - 2)
y_prof = np.concatenate(([1], np.flip(interior), [0]))
### calculates y points ###
### calculates temperature profile ###
# Build the network input from every parameter set and depth point, then
# evaluate the MLP. The result is indexed below as [simulation, depth index].
n_sims = len(r_list)
x_in = get_input(r_list, t_list, v_list, y_prof)
y_pred_nn_pointwise = get_profile(x_in, mlp, num_sims=n_sims)
### calculates temperature profile ###
### writes out temperature profile ###
if write_file:
    import os  # only needed when writing output files

    # open() would fail with FileNotFoundError if outputs/ does not exist yet.
    os.makedirs("outputs", exist_ok=True)
    for i in range(len(r_list)):
        fname = "outputs/profile_raq_ra" + str(r_list[i]) + "_fkt" + str(t_list[i]) + "_fkv" + str(v_list[i])
        # One "<depth> <temperature>" pair per line. ASCII text with "\n"
        # endings on every platform (same bytes as the former binary writes).
        # The with-block closes the file even if a write fails.
        with open(fname + ".txt", "w", encoding="ascii", newline="\n") as f:
            f.writelines(str(depth) + " " + str(temp) + "\n"
                         for depth, temp in zip(y_prof, y_pred_nn_pointwise[i]))
### writes out temperature profile ###
### plots temperature profile ###
# Previously the plot_profile flag was ignored and plots were always produced.
if plot_profile:
    import os  # only needed when saving plots

    # savefig would fail if outputs/ does not exist yet.
    os.makedirs("outputs", exist_ok=True)
    for i in range(len(r_list)):
        fname = "outputs/profile_raq_ra" + str(r_list[i]) + "_fkt" + str(t_list[i]) + "_fkv" + str(v_list[i])
        fig = plt.figure()
        plt.plot(y_pred_nn_pointwise[i,:], y_prof, 'k-', linewidth=3.0, label="pointwise neural network")
        plt.ylim([1,0])  # inverted y-axis: depth 0 at the top, 1 at the bottom
        plt.xlabel("Temperature")
        plt.ylabel("Depth")
        plt.legend()
        plt.grid()
        plt.savefig(fname + ".png")
        # Release the figure; otherwise one open figure per simulation accumulates.
        plt.close(fig)
### plots temperature profile ###