# steadystate-mantle / calculate_baselines.py
# Author: agsiddhant -- commit 893b843 "more_baselining (#1)" (verified)
from matplotlib import pyplot as plt
from utils import *
import warnings
from sklearn.kernel_ridge import KernelRidge
from sklearn.linear_model import LinearRegression
from sklearn.neighbors import KDTree
from sklearn.neighbors import NearestNeighbors
#### Define outputs ####
# Output switches for this script.
#   write_file   -- gates writing each predicted profile to outputs/*.txt
#   plot_profile -- intended to gate saving a PNG plot of each profile
#                   (NOTE(review): confirm the plotting section checks it)
write_file, plot_profile = True, True
#### Define outputs ####
# load data
# x_p / y_p are dicts keyed by data split; only the "train" split is used
# below. pickle.load executes arbitrary code, so these files must come from
# a trusted location.
pre = "/plp_user/agar_sh/PBML/pytorch/TPH/MLP/profiles/"
_loaded = []
for _pkl_name in ('x_p.pkl', 'y_p.pkl'):  # load x_p first, then y_p
    with open(pre + _pkl_name, 'rb') as _fh:
        _loaded.append(pickle.load(_fh))
x_p, y_p = _loaded
# calculate baselines
# Every baseline is fitted on the same training inputs/targets.
X_train, Y_train = x_p["train"], y_p["train"]
linear_reg = LinearRegression().fit(X_train, Y_train)
krr = KernelRidge(alpha=0.01, kernel="rbf").fit(X_train, Y_train)
# Neighbour-search structures over the training inputs: the KDTree is queried
# later for inverse-distance interpolation, the ball-tree NearestNeighbors
# model for the single closest training sample.
tree = KDTree(X_train, leaf_size=2)
nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(X_train)
# read the simulation parameters and ensure correct formatting
# Expected file format: one line per parameter list, for example
#     r_list = 0.5, 1.0, 2.0,
# The trailing comma after the last value is mandatory; the empty field it
# leaves after splitting on "," is discarded.
f_nn = "my_simulation_parameters.txt"
_PARAM_KEYS = ("r_list", "t_list", "v_list")


def _parse_param_values(text, source):
    """Return the floats listed after '=' in a ``name = v1, v2, ...,`` line.

    Raises Exception if the line does not end with the required trailing comma.
    """
    if not text.endswith(","):
        raise Exception("Ensure there is a comma after last parameter value in " + source)
    return [float(p) for p in text.split("=")[1].split(",")[:-1]]


_params = {}
with open(f_nn) as fw:
    for raw_line in fw:
        stripped = raw_line.rstrip()
        # The first matching key wins, in the order r_list, t_list, v_list.
        for key in _PARAM_KEYS:
            if key in stripped:
                _params[key] = _parse_param_values(stripped, f_nn)
                break

_missing = [key for key in _PARAM_KEYS if key not in _params]
if _missing:
    raise Exception("Missing " + ", ".join(_missing) + " in " + f_nn)
r_list, t_list, v_list = (_params[key] for key in _PARAM_KEYS)

# All three lists must have the same length. Parenthesised chained comparison:
# `not a == b and a == c` would parse as `(not a == b) and (a == c)`.
if not (len(r_list) == len(t_list) == len(v_list)):
    raise Exception("Ensure equal number of values for all parameters in " + f_nn)
# Check parameter ranges and print appropriate warnings
# Each requested value is compared with the span covered by the training
# dataset. Out-of-range values only warn; predictions are still produced.
_range_checks = (
    (r_list, 0, 9.5, 'RaQ/Ra is outside the range of the training dataset'),
    (t_list, 1e+6, 5e+9, 'FKT is outside the range of the training dataset'),
    (v_list, 1, 95, 'FKV is outside the range of the training dataset'),
)
for idx in range(len(r_list)):
    for values, low, high, message in _range_checks:
        if values[idx] < low or values[idx] > high:
            warnings.warn(message)
### calculates y points ###
# Vertical grid of num_points values running from 1 down to 0: the two end
# points plus (num_points - 2) evenly spaced interior points between
# 1/(2*num_points) and 1 - 1/(2*num_points), in descending order.
num_points = 128
half_step = 1 / (2 * num_points)
interior = np.linspace(half_step, 1 - half_step, num_points - 2)[::-1]
y_prof = np.hstack(([1], interior, [0]))
### calculates y points ###
### calculates temperature profile ###
# Model inputs for every requested (r, t, v) triple; get_input is defined
# outside this section (presumably in utils).
x_in = get_input(r_list, t_list, v_list)
y_train = y_p["train"]

y_pred = {}
y_pred["linear"] = linear_reg.predict(x_in)
y_pred["krr"] = krr.predict(x_in)

# Nearest neighbour: copy the single closest training profile. Indexing with
# the first (only) neighbour column keeps the true profile length instead of
# a hard-coded 128.
_, indices = nbrs.kneighbors(x_in)
y_pred["neighbor"] = y_train[indices[:, 0], :]

# Inverse-distance weighted average of the k nearest training profiles.
k_interp = 3
dist, ind = tree.query(x_in, k=k_interp)
exact = dist == 0.0
with np.errstate(divide="ignore"):
    weights = 1.0 / dist
# A query that coincides with a training point would give inf/inf = NaN;
# put all the weight on the zero-distance neighbour(s) instead.
has_exact = exact.any(axis=1)
weights[has_exact] = exact[has_exact]
weights /= weights.sum(axis=1, keepdims=True)
# y_train[ind] has shape (n_queries, k_interp, n_profile_points).
y_pred["interp"] = (weights[:, :, np.newaxis] * y_train[ind]).sum(axis=1)
### calculates temperature profile ###
### writes out temperature profile ###
# For each baseline and each requested parameter triple, write one
# "<y_prof value> <predicted temperature>" pair per line to
# outputs/<algorithm>_profile_raq_ra<r>_fkt<t>_fkv<v>.txt
if write_file:
    import os

    # open() below fails if the output directory does not exist yet.
    os.makedirs("outputs", exist_ok=True)
    for algorithm in ["linear", "krr", "interp", "neighbor"]:
        for i, (raq, fkt, fkv) in enumerate(zip(r_list, t_list, v_list)):
            fname = ("outputs/" + algorithm + "_profile_raq_ra" + str(raq)
                     + "_fkt" + str(fkt) + "_fkv" + str(fkv))
            # ASCII text with "\n" line endings on every platform; the context
            # manager closes the file even if a write fails.
            with open(fname + ".txt", "w", encoding="ascii", newline="\n") as f:
                f.writelines(str(depth) + " " + str(temp) + "\n"
                             for depth, temp in zip(y_prof, y_pred[algorithm][i]))
### writes out temperature profile ###
### plots temperature profile ###
# For each baseline and each requested parameter triple, save a plot of
# temperature against the y_prof grid next to the matching .txt output.
# Skipped entirely when plot_profile is False.
if plot_profile:
    import os

    # savefig fails if the output directory does not exist yet.
    os.makedirs("outputs", exist_ok=True)
    for algorithm in ["linear", "krr", "interp", "neighbor"]:
        for i, (raq, fkt, fkv) in enumerate(zip(r_list, t_list, v_list)):
            fname = ("outputs/" + algorithm + "_profile_raq_ra" + str(raq)
                     + "_fkt" + str(fkt) + "_fkv" + str(fkv))
            fig = plt.figure()
            # Label with the actual baseline name; these are not neural-network predictions.
            plt.plot(y_pred[algorithm][i, :], y_prof, 'k-', linewidth=3.0,
                     label=algorithm + " baseline")
            plt.ylim([1, 0])
            plt.xlabel("Temperature")
            plt.ylabel("Depth")
            plt.legend()
            plt.grid()
            plt.savefig(fname + ".png")
            # pyplot keeps every figure alive until closed; release it to
            # avoid unbounded memory growth across many profiles.
            plt.close(fig)
### plots temperature profile ###