DeepStruc / app_old.py
AndySAnker's picture
Rename app.py to app_old.py
829530b
raw
history blame
10.6 kB
import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
import datetime
import os
import sys, argparse
# Function to format the data from the ASCII file
def get_data(args):
ct = str(datetime.datetime.now()).replace(' ', '_').replace(':','-').replace('.','-')
project_name = f'{args.save_path}/DeepStruc_{ct}'
if not os.path.isdir(f'{project_name}'):
os.mkdir(f'{project_name}')
this_path = args.data
samples = args.num_samples
if os.path.isdir(this_path):
files = sorted(os.listdir(this_path))
else:
files = [this_path]
this_path = '.'
x_list, y_list, name_list = [], [], []
idxx = 0
np_data = np.zeros((len(files)*samples, 2800))
for idx, file in enumerate(files):
for skip_row in range(100):
try:
data = np.loadtxt(f'{this_path}/{file}', skiprows=skip_row)
except ValueError:
continue
data = data.T
x_list.append(data[0])
y_list.append(data[1])
Gr_ph = data[1]
if round(data[0][1] - data[0][0],2) != 0.01:
raise ValueError("The PDF does not have an r-step of 0.01 Å")
try:
start_PDF = np.where((data[0] > 1.995) & (data[0] < 2.005))[0][0]
except:
Gr_ph = np.concatenate((np.zeros((int((data[0][0])/0.01))), Gr_ph))
try:
end_PDF = np.where((data[0] > 29.995) & (data[0] < 30.005))[0][0]
except:
Gr_ph = np.concatenate((Gr_ph, np.zeros((3000-len(Gr_ph)))))
Gr_ph = Gr_ph[200:3000]
for i in range(samples):
np_data[idxx] = Gr_ph
np_data[idxx] /= np.amax(np_data[idxx])
idxx += 1
name_list.append(file)
break
def get_model(model_dir):
if model_dir == 'DeepStruc':
with open(f'./models/DeepStruc/model_arch.yaml') as file:
model_arch = yaml.full_load(file)
model_path = './models/DeepStruc/models/DeepStruc.ckpt'
return model_path, model_arch
if os.path.isdir(model_dir):
if 'models' in os.listdir(model_dir):
models = sorted(os.listdir(f'{model_dir}/models'))
models = [model for model in models if '.ckpt' in model]
print(f'No specific model was provided. {models[0]} was chosen.')
print('Dataloader might not be sufficient in loading dimensions.')
model_path = f'{model_dir}/models/{models[0]}'
with open(f'{model_dir}/model_arch.yaml') as file:
model_arch = yaml.full_load(file)
return model_path, model_arch
else:
print(f'Path not understood: {model_dir}')
else:
idx = model_dir.rindex('/')
with open(f'{model_dir[:idx-6]}model_arch.yaml') as file:
model_arch = yaml.full_load(file)
return model_dir, model_arch
np_data = np_data.reshape((len(files)*samples, 2800, 1))
return np_data, name_list, project_name
def format_predictions(latent_space, data_names, mus, sigmas, sigma_inc):
df_preds = pd.DataFrame(columns=['x', 'y', 'file_name', 'mu', 'sigma', 'sigma_inc'])
for i,j, mu, sigma in zip(latent_space, data_names, mus, sigmas):
if '/' in j:
j = j.split('/')[-1]
if '.' in j:
j_idx = j.rindex('.')
j = j[:j_idx]
info_dict = {
'x': i[0].detach().cpu().numpy(),
'y': i[1].detach().cpu().numpy(),
'file_name': j,
'mu': mu.detach().cpu().numpy(),
'sigma': sigma.detach().cpu().numpy(),
'sigma_inc': sigma_inc,
}
df_preds = df_preds.append(info_dict, ignore_index=True)
return df_preds
def plot_ls(df, mk_dir, index_highlight):
if not os.path.isdir(mk_dir):
os.mkdir(mk_dir)
ideal_ls = './tools/ls_points.csv'
color_dict = {
'FCC': '#19ADFF',
'BCC': '#4F8F00',
'SC': '#941100',
'Octahedron': '#212121',
'Icosahedron': '#005493',
'Decahedron': '#FF950E',
'HCP': '#FF8AD8',
}
df_ideal = pd.read_csv(ideal_ls, index_col=0) # Get latent space data
# Plotting inputs
## Training and validation data
MARKER_SIZE_TR = 60
EDGE_LINEWIDTH_TR = 0.0
ALPHA_TR = 0.3
## Figure
FIG_SIZE = (10, 4)
MARKER_SIZE_FG = 60
MARKER_FONT_SIZE = 10
MARKER_SCALE = 1.5
fig = plt.figure(figsize=FIG_SIZE)
gs = GridSpec(1, 5, figure=fig)
ax = fig.add_subplot(gs[0, :4])
ax_legend = fig.add_subplot(gs[0, 4])
if index_highlight >= len(df):
print(f'\nIndex argument is to large! Need to be smaller than {len(df)} but was {index_highlight}')
raise IndexError
elif index_highlight < -1:
print(f'\nIndex argument invalid! Must be integer from -1 to number of samples generated.')
raise ValueError
elif index_highlight==-1:
pass
elif len(df['file_name'].unique()) > 1:
print(f'\nCan only show highlight index if --data is specific file but {len(df["file_name"].unique())} files were loaded.')
else:
print(f'\nHighlighting index {index_highlight} from the {df["file_name"].unique()[0]} sampling pool.')
ax.scatter(df.iloc[index_highlight]['x'], df.iloc[index_highlight]['y'], c='k', s=40,
linewidth=0.0, marker='o', zorder=3)
ax.scatter(df.iloc[index_highlight]['x'], df.iloc[index_highlight]['y'], c='w', s=25,
linewidth=0.0, marker='o', zorder=3)
ax.scatter(df.iloc[index_highlight]['x'], df.iloc[index_highlight]['y'], c='k', s=10,
linewidth=0.0, marker='o', zorder=3)
ax.scatter(df.iloc[index_highlight]['x'], df.iloc[index_highlight]['y'], c='w', s=1,
linewidth=0.0, marker='o', zorder=3)
print('\nPlotting DeepStruc training + validation data.')
pbar = tqdm(total=len(df_ideal))
for idx in range(len(df_ideal)):
ax.scatter(df_ideal.iloc[idx]['x'], df_ideal.iloc[idx]['y'],
c=color_dict[df_ideal.iloc[idx]['stru_type']], s=MARKER_SIZE_TR * df_ideal.iloc[idx]['size'],
edgecolors='k', linewidth=EDGE_LINEWIDTH_TR,
alpha=ALPHA_TR)
pbar.update()
pbar.close()
mlines_list = []
for key in color_dict.keys():
mlines_list.append(
mlines.Line2D([], [], MARKER_SIZE_FG, marker='o', c=color_dict[key], linestyle='None', label=key,
mew=1))
from matplotlib import cm
cm_subsection = np.linspace(0, 1, len(df.file_name.unique()))
data_color = [cm.magma(x) for x in cm_subsection]
print('\nPlotting DeepStruc structure sampling.')
pbar = tqdm(total=len(df.file_name.unique()))
for idx, file_name in enumerate(df.file_name.unique()):
this_c = np.array([data_color[idx]])
df_ph = df[df.file_name==file_name]
df_ph.reset_index(drop=True, inplace=True)
ax.scatter(df_ph['mu'][0][0],df_ph['mu'][0][1], c=this_c, s=10, edgecolors='k',
linewidth=0.5, marker='D',zorder=1)
ellipse = Ellipse((df_ph['mu'][0][0],df_ph['mu'][0][1]),df_ph['sigma'][0][0],df_ph['sigma'][0][1], ec='k', fc=this_c, alpha=0.5, fill=True, zorder=-1)
ax.add_patch(ellipse)
ellipse = Ellipse((df_ph['mu'][0][0],df_ph['mu'][0][1]),df_ph['x'].var(),df_ph['y'].var(), ec='k', fc=this_c, alpha=0.2, fill=True, zorder=-1)
ax.add_patch(ellipse)
mlines_list.append(
mlines.Line2D([], [], MARKER_SIZE_FG, marker='D', c=this_c, linestyle='None', label=file_name, mec='k',
mew=1))
for index, sample in df_ph.iterrows():
ax.scatter(sample['x'], sample['y'], c=this_c, s=10, edgecolors='k',
linewidth=0.8, marker='o', zorder=2)
pbar.update()
pbar.close()
ax_legend.legend(handles=mlines_list,fancybox=True, #ncol=2, #, bbox_to_anchor=(0.8, 0.5)
markerscale=MARKER_SCALE, fontsize=MARKER_FONT_SIZE, loc='upper right')
ax.set_xlabel('Latent space $\mathregular{z_0}$', size=10) # Latent Space Feature 1
ax.set_ylabel('Latent space $\mathregular{z_1}$', size=10)
ax_legend.spines['top'].set_visible(False)
ax_legend.spines['right'].set_visible(False)
ax_legend.spines['bottom'].set_visible(False)
ax_legend.spines['left'].set_visible(False)
ax_legend.get_xaxis().set_ticks([])
ax_legend.get_yaxis().set_ticks([])
ax.get_xaxis().set_ticks([])
ax.get_yaxis().set_ticks([])
plt.tight_layout()
plt.savefig(f'{mk_dir}/ls.png',dpi=300)
return None
def save_predictions(xyz_pred, df, project_name, model_arch, args):
print('\nSaving predicted structures as XYZ files.')
if not os.path.isdir(f'{project_name}'):
os.mkdir(f'{project_name}')
with open(f'{project_name}/args.yaml', 'w') as outfile:
yaml.dump(vars(args), outfile, allow_unicode=True, default_flow_style=False)
pbar = tqdm(total=len(df))
for count, (idx, row) in enumerate(df.iterrows()):
if not os.path.isdir(f'{project_name}/{row["file_name"]}'):
os.mkdir(f'{project_name}/{row["file_name"]}')
x = f'{float(row["x"]):+.3f}'.replace('.', '-')
y = f'{float(row["y"]):+.3f}'.replace('.', '-')
save_xyz_file(f'{project_name}/{row["file_name"]}',
xyz_pred[idx].detach().cpu().numpy(),
f'{row["file_name"]}_{count:05}_ls_{x}_{y}',
[model_arch['norm_vals']['x'],model_arch['norm_vals']['y'],model_arch['norm_vals']['z']])
pbar.update()
pbar.close()
return None
# Allow the user to upload an ASCII file from their local computer
PDFFile = st.file_uploader("Upload your ASCII file")
data, data_name, project_name = get_data(args)
# Load model
model_path, model_arch = get_model(args.model)
Net(model_arch=model_arch)
DeepStruc = Net.load_from_checkpoint(model_path,model_arch=model_arch)
xyz_pred, latent_space, kl, mu, sigma = DeepStruc(data, mode='prior', sigma_scale=args.sigma)
samling_pairs = format_predictions(latent_space, data_name, mu, sigma, args.sigma)
if args.plot_sampling == True and args.model == 'DeepStruc':
plot_ls(samling_pairs, project_name, args.index_plot)
elif args.plot_sampling == True and args.model != 'DeepStruc':
print("Argument '--model' needs to be default DeepStruc value for plot to be generated!")
save_predictions(xyz_pred, samling_pairs, project_name, model_arch, args)