# Spaces: Sleeping  (page-header residue from where this file was copied; not code)
import argparse
import datetime
import os
import sys
import tempfile

import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from matplotlib.gridspec import GridSpec
from matplotlib.patches import Ellipse

# NOTE(review): `yaml`, `tqdm`, `Net` and `save_xyz_file` are used below but are
# not imported anywhere in this file. TODO: import them (PyYAML, tqdm and,
# presumably, the DeepStruc project modules) before running the app.
# Helpers for loading PDF data from ASCII files and for locating the model.

_R_STEP = 0.01          # required r-grid spacing (Å)
_R_MIN_INDEX = 200      # grid index of r = 2.00 Å when index i means r = i * 0.01 Å
_R_MAX_INDEX = 3000     # grid index of r = 30.00 Å (exclusive upper bound)
_MAX_HEADER_ROWS = 100  # leading rows tried as a header before giving up on a file


def _load_pdf_columns(path, max_header_rows=_MAX_HEADER_ROWS):
    """Return the first two numeric columns (r, G(r)) of a text file.

    ``np.loadtxt`` is retried with 0, 1, 2, ... leading rows skipped until it
    parses, which tolerates a free-form header.

    Raises:
        ValueError: no parse succeeded within ``max_header_rows`` attempts, or
            the parsed data does not have at least two rows and two columns.
    """
    for skip_row in range(max_header_rows):
        try:
            data = np.loadtxt(path, skiprows=skip_row)
        except ValueError:
            continue
        if data.ndim != 2 or data.shape[1] < 2:
            raise ValueError(f'{path} must contain at least two rows of two numeric columns (r, G(r))')
        return data[:, 0], data[:, 1]
    raise ValueError(f'Could not parse {path} as numeric columns after skipping up to {max_header_rows} rows')


def _align_to_model_grid(r, gr):
    """Place G(r) on a grid where index i means r = i * 0.01 Å; return r = 2.00-29.99 Å.

    Missing points (below the first or above the last measured r) are zero.

    Raises:
        ValueError: the r-step is not 0.01 Å.
    """
    if round(r[1] - r[0], 2) != _R_STEP:
        raise ValueError("The PDF does not have an r-step of 0.01 Å")
    # Number of grid points between r = 0 and the first measured r. Use
    # round(), not int(): float division can yield e.g. 28.999999999999996.
    offset = int(round(r[0] / _R_STEP))
    if offset > 0:
        gr = np.concatenate((np.zeros(offset), gr))
    elif offset < 0:
        gr = gr[-offset:]
    if len(gr) < _R_MAX_INDEX:
        gr = np.concatenate((gr, np.zeros(_R_MAX_INDEX - len(gr))))
    return gr[_R_MIN_INDEX:_R_MAX_INDEX]


def get_data(args):
    """Load PDF file(s) into the fixed-size array fed to DeepStruc.

    Args:
        args: namespace providing
            ``save_path`` - directory in which a timestamped
                ``DeepStruc_<timestamp>`` project directory is created;
            ``data`` - a data file, or a directory whose entries are all data
                files (read in sorted order);
            ``num_samples`` - number of copies of each PDF placed in the output.

    Returns:
        ``(np_data, name_list, project_name)``. ``np_data`` has shape
        ``(n_files * num_samples, 2800, 1)`` and holds G(r) for
        r = 2.00-29.99 Å, each row divided by its maximum. ``name_list`` has
        one entry (the file name as listed/passed) per row.

    Raises:
        ValueError: a file cannot be parsed or its r-step is not 0.01 Å.
    """
    ct = str(datetime.datetime.now()).replace(' ', '_').replace(':', '-').replace('.', '-')
    project_name = f'{args.save_path}/DeepStruc_{ct}'
    os.makedirs(project_name, exist_ok=True)

    this_path = args.data
    samples = args.num_samples
    if os.path.isdir(this_path):
        files = sorted(os.listdir(this_path))
    else:
        files = [this_path]
        this_path = '.'

    n_points = _R_MAX_INDEX - _R_MIN_INDEX
    np_data = np.zeros((len(files) * samples, n_points))
    name_list = []
    row = 0
    for file in files:
        # os.path.join keeps absolute single-file paths intact (an f-string
        # would turn '/abs/file' into './/abs/file').
        r, gr = _load_pdf_columns(os.path.join(this_path, file))
        profile = _align_to_model_grid(r, gr)
        profile = profile / np.amax(profile)
        np_data[row:row + samples] = profile
        row += samples
        name_list.extend([file] * samples)

    np_data = np_data.reshape((len(files) * samples, n_points, 1))
    return np_data, name_list, project_name


def _read_model_arch(arch_path):
    """Parse a ``model_arch.yaml`` file.

    NOTE(review): ``yaml.full_load`` is not the safe loader; acceptable only
    while model directories come from trusted local paths.
    """
    with open(arch_path) as file:
        return yaml.full_load(file)


def get_model(model_dir):
    """Resolve a checkpoint path and its parsed architecture description.

    Args:
        model_dir: ``'DeepStruc'`` for the bundled model; a directory holding
            ``model_arch.yaml`` and a ``models/`` folder of ``.ckpt`` files
            (the alphabetically first checkpoint is used); or the path of a
            checkpoint stored as ``<root>/models/<name>.ckpt``.

    Returns:
        ``(model_path, model_arch)``.

    Raises:
        FileNotFoundError: a directory without a ``models/`` sub-directory or
            without any ``.ckpt`` file in it, or a missing YAML file.
    """
    if model_dir == 'DeepStruc':
        model_arch = _read_model_arch('./models/DeepStruc/model_arch.yaml')
        return './models/DeepStruc/models/DeepStruc.ckpt', model_arch

    if os.path.isdir(model_dir):
        ckpt_dir = os.path.join(model_dir, 'models')
        if not os.path.isdir(ckpt_dir):
            raise FileNotFoundError(f'Path not understood: {model_dir}')
        models = sorted(name for name in os.listdir(ckpt_dir) if name.endswith('.ckpt'))
        if not models:
            raise FileNotFoundError(f'No .ckpt file found in {ckpt_dir}')
        print(f'No specific model was provided. {models[0]} was chosen.')
        print('Dataloader might not be sufficient in loading dimensions.')
        model_path = os.path.join(ckpt_dir, models[0])
        return model_path, _read_model_arch(os.path.join(model_dir, 'model_arch.yaml'))

    # A checkpoint file: its architecture lives two levels up, next to 'models/'.
    root = os.path.dirname(os.path.dirname(model_dir))
    return model_dir, _read_model_arch(os.path.join(root, 'model_arch.yaml'))
def format_predictions(latent_space, data_names, mus, sigmas, sigma_inc):
    """Collect per-sample model outputs into a DataFrame.

    Args:
        latent_space: per-sample latent points; element ``[0]`` / ``[1]`` of
            each becomes ``x`` / ``y``.
        data_names: per-sample source file names; reduced to the base name
            after the last '/' with the last extension removed.
        mus, sigmas: per-sample values stored as NumPy arrays.
        sigma_inc: scalar copied into every row.
        Items are expected to expose ``.detach().cpu().numpy()`` (presumably
        torch tensors - confirm against the model output).

    Returns:
        DataFrame with columns ``x``, ``y`` (floats), ``file_name``, ``mu``,
        ``sigma`` and ``sigma_inc``. The frame is empty (with those columns)
        for empty input.
    """
    columns = ['x', 'y', 'file_name', 'mu', 'sigma', 'sigma_inc']
    # Build all rows first: DataFrame.append was removed in pandas 2.0 and
    # appending row by row was quadratic anyway.
    records = []
    for point, name, mu, sigma in zip(latent_space, data_names, mus, sigmas):
        if '/' in name:
            name = name.split('/')[-1]
        if '.' in name:
            name = name[:name.rindex('.')]
        records.append({
            'x': float(point[0].detach().cpu().numpy()),
            'y': float(point[1].detach().cpu().numpy()),
            'file_name': name,
            'mu': mu.detach().cpu().numpy(),
            'sigma': sigma.detach().cpu().numpy(),
            'sigma_inc': sigma_inc,
        })
    return pd.DataFrame(records, columns=columns)
def plot_ls(df, mk_dir, index_highlight):
    """Plot sampled latent points on top of the DeepStruc training/validation map.

    Args:
        df: predictions as built by ``format_predictions`` (columns ``x``,
            ``y``, ``file_name``, ``mu``, ``sigma``).
        mk_dir: output directory, created if missing; the figure is saved as
            ``<mk_dir>/ls.png`` at 300 dpi.
        index_highlight: row of ``df`` to mark with a target symbol, or -1 for
            none. It is only drawn when ``df`` contains a single file.

    Raises:
        IndexError: ``index_highlight`` is not smaller than ``len(df)``.
        ValueError: ``index_highlight`` is below -1.
    """
    # Validate before creating the figure so an error cannot leave one open.
    if index_highlight >= len(df):
        msg = f'Index argument is to large! Need to be smaller than {len(df)} but was {index_highlight}'
        print(f'\n{msg}')
        raise IndexError(msg)
    if index_highlight < -1:
        msg = 'Index argument invalid! Must be integer from -1 to number of samples generated.'
        print(f'\n{msg}')
        raise ValueError(msg)

    os.makedirs(mk_dir, exist_ok=True)
    ideal_ls = './tools/ls_points.csv'
    color_dict = {
        'FCC': '#19ADFF',
        'BCC': '#4F8F00',
        'SC': '#941100',
        'Octahedron': '#212121',
        'Icosahedron': '#005493',
        'Decahedron': '#FF950E',
        'HCP': '#FF8AD8',
    }
    df_ideal = pd.read_csv(ideal_ls, index_col=0)  # Training/validation latent points

    # Plotting inputs
    ## Training and validation data
    MARKER_SIZE_TR = 60
    EDGE_LINEWIDTH_TR = 0.0
    ALPHA_TR = 0.3
    ## Figure
    FIG_SIZE = (10, 4)
    MARKER_SIZE_FG = 60
    MARKER_FONT_SIZE = 10
    MARKER_SCALE = 1.5

    file_names = df['file_name'].unique()

    fig = plt.figure(figsize=FIG_SIZE)
    try:
        gs = GridSpec(1, 5, figure=fig)
        ax = fig.add_subplot(gs[0, :4])
        ax_legend = fig.add_subplot(gs[0, 4])

        if index_highlight == -1:
            pass
        elif len(file_names) > 1:
            print(f'\nCan only show highlight index if --data is specific file but {len(file_names)} files were loaded.')
        else:
            print(f'\nHighlighting index {index_highlight} from the {file_names[0]} sampling pool.')
            point = df.iloc[index_highlight]
            # Alternating black/white discs of decreasing size form a target marker.
            for colour, size in (('k', 40), ('w', 25), ('k', 10), ('w', 1)):
                ax.scatter(point['x'], point['y'], c=colour, s=size,
                           linewidth=0.0, marker='o', zorder=3)

        print('\nPlotting DeepStruc training + validation data.')
        pbar = tqdm(total=len(df_ideal))
        # One vectorized call instead of one scatter (and several iloc lookups) per point.
        ax.scatter(df_ideal['x'], df_ideal['y'],
                   c=[color_dict[stru_type] for stru_type in df_ideal['stru_type']],
                   s=MARKER_SIZE_TR * df_ideal['size'],
                   edgecolors='k', linewidth=EDGE_LINEWIDTH_TR,
                   alpha=ALPHA_TR)
        pbar.update(len(df_ideal))
        pbar.close()

        mlines_list = [
            mlines.Line2D([], [], MARKER_SIZE_FG, marker='o', c=color_dict[key], linestyle='None', label=key,
                          mew=1)
            for key in color_dict
        ]

        from matplotlib import cm
        cm_subsection = np.linspace(0, 1, len(file_names))
        data_color = [cm.magma(x) for x in cm_subsection]
        print('\nPlotting DeepStruc structure sampling.')
        pbar = tqdm(total=len(file_names))
        for colour, file_name in zip(data_color, file_names):
            this_c = np.array([colour])
            df_ph = df[df['file_name'] == file_name].reset_index(drop=True)
            xs = df_ph['x'].astype(float)
            ys = df_ph['y'].astype(float)
            mu_x, mu_y = df_ph['mu'][0][0], df_ph['mu'][0][1]
            ax.scatter(mu_x, mu_y, c=this_c, s=10, edgecolors='k',
                       linewidth=0.5, marker='D', zorder=1)
            # Ellipse sized by the file's first sigma, then a fainter one sized
            # by the sample variance of the drawn latent points.
            ax.add_patch(Ellipse((mu_x, mu_y), df_ph['sigma'][0][0], df_ph['sigma'][0][1],
                                 ec='k', fc=this_c, alpha=0.5, fill=True, zorder=-1))
            ax.add_patch(Ellipse((mu_x, mu_y), xs.var(), ys.var(),
                                 ec='k', fc=this_c, alpha=0.2, fill=True, zorder=-1))
            mlines_list.append(
                mlines.Line2D([], [], MARKER_SIZE_FG, marker='D', c=this_c, linestyle='None', label=file_name, mec='k',
                              mew=1))
            ax.scatter(xs, ys, c=this_c, s=10, edgecolors='k',
                       linewidth=0.8, marker='o', zorder=2)
            pbar.update()
        pbar.close()

        ax_legend.legend(handles=mlines_list, fancybox=True,
                         markerscale=MARKER_SCALE, fontsize=MARKER_FONT_SIZE, loc='upper right')
        # Raw strings: '\m' is an invalid escape sequence in a normal literal.
        ax.set_xlabel(r'Latent space $\mathregular{z_0}$', size=10)  # Latent Space Feature 1
        ax.set_ylabel(r'Latent space $\mathregular{z_1}$', size=10)
        for side in ('top', 'right', 'bottom', 'left'):
            ax_legend.spines[side].set_visible(False)
        ax_legend.get_xaxis().set_ticks([])
        ax_legend.get_yaxis().set_ticks([])
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])
        fig.tight_layout()
        fig.savefig(f'{mk_dir}/ls.png', dpi=300)
    finally:
        # pyplot keeps every figure alive until it is closed, so repeated calls
        # (e.g. Streamlit reruns) would otherwise accumulate figures.
        plt.close(fig)
    return None
def save_predictions(xyz_pred, df, project_name, model_arch, args):
    """Dump the run arguments and write one XYZ file per predicted structure.

    Layout on disk:
        <project_name>/args.yaml  -- ``vars(args)`` serialized with yaml.dump
        <project_name>/<file_name>/<file_name>_<count:05>_ls_<x>_<y>
    where <x>/<y> are the row's latent coordinates formatted as signed
    3-decimal numbers with '.' replaced by '-'. ``xyz_pred`` is indexed by each
    row's DataFrame index label, and ``model_arch['norm_vals']`` supplies the
    x/y/z values passed to ``save_xyz_file``.
    """
    print('\nSaving predicted structures as XYZ files.')
    root = f'{project_name}'
    if not os.path.isdir(root):
        os.mkdir(root)
    with open(f'{root}/args.yaml', 'w') as outfile:
        yaml.dump(vars(args), outfile, allow_unicode=True, default_flow_style=False)

    def _coord_tag(value):
        # e.g. 1.2345 -> '+1-234': the coordinate text contains no '.'.
        return f'{float(value):+.3f}'.replace('.', '-')

    progress = tqdm(total=len(df))
    for count, (label, sample) in enumerate(df.iterrows()):
        name = sample["file_name"]
        target_dir = f'{root}/{name}'
        if not os.path.isdir(target_dir):
            os.mkdir(target_dir)
        x_tag = _coord_tag(sample["x"])
        y_tag = _coord_tag(sample["y"])
        structure = xyz_pred[label].detach().cpu().numpy()
        norm_vals = model_arch['norm_vals']
        save_xyz_file(target_dir,
                      structure,
                      f'{name}_{count:05}_ls_{x_tag}_{y_tag}',
                      [norm_vals['x'], norm_vals['y'], norm_vals['z']])
        progress.update()
    progress.close()
    return None
# ---------------------------------------------------------------------------
# Streamlit app: upload a PDF file, run DeepStruc on it, save the predictions.
# ---------------------------------------------------------------------------
# Command-line flags can be passed after `--`, e.g.
#   streamlit run <this file> -- --num_samples 10 --plot_sampling
# NOTE(review): the defaults below were chosen for this app because the
# original never defined `args`. TODO confirm they match the DeepStruc CLI
# these helpers were ported from.
parser = argparse.ArgumentParser(description='DeepStruc Streamlit front-end.')
parser.add_argument('-m', '--model', default='DeepStruc', type=str,
                    help="'DeepStruc', a model directory, or a checkpoint path.")
parser.add_argument('-n', '--num_samples', default=5, type=int,
                    help='Number of structures sampled for the uploaded PDF.')
parser.add_argument('-s', '--sigma', default=7.0, type=float,
                    help='Passed to the model as sigma_scale.')
parser.add_argument('-p', '--plot_sampling', action='store_true',
                    help='Plot samples on the latent-space map (default model only).')
parser.add_argument('-g', '--save_path', default='.', type=str,
                    help='Directory in which the project folder is created.')
parser.add_argument('-i', '--index_plot', default=-1, type=int,
                    help='Sample index to highlight in the latent-space plot; -1 for none.')
# parse_known_args so unrelated entries in sys.argv do not abort the app.
args, _ = parser.parse_known_args(sys.argv[1:])

# Allow the user to upload an ASCII file from their local computer.
PDFFile = st.file_uploader("Upload your ASCII file")
if PDFFile is None:
    # Nothing uploaded yet: end this run instead of processing missing data.
    st.stop()

# get_data() reads from disk, so write the upload to a temporary file that
# keeps the original name (the name becomes the prediction's file_name).
with tempfile.TemporaryDirectory() as upload_dir:
    args.data = os.path.join(upload_dir, os.path.basename(PDFFile.name))
    with open(args.data, 'wb') as handle:
        handle.write(PDFFile.getbuffer())
    data, data_name, project_name = get_data(args)

# Load model. The discarded `Net(model_arch=...)` instance was removed; this
# assumes Net.load_from_checkpoint is a classmethod that builds the network
# itself (as for a PyTorch Lightning module). TODO confirm.
model_path, model_arch = get_model(args.model)
DeepStruc = Net.load_from_checkpoint(model_path, model_arch=model_arch)
xyz_pred, latent_space, kl, mu, sigma = DeepStruc(data, mode='prior', sigma_scale=args.sigma)

sampling_pairs = format_predictions(latent_space, data_name, mu, sigma, args.sigma)
if args.plot_sampling:
    if args.model == 'DeepStruc':
        plot_ls(sampling_pairs, project_name, args.index_plot)
    else:
        print("Argument '--model' needs to be default DeepStruc value for plot to be generated!")
save_predictions(xyz_pred, sampling_pairs, project_name, model_arch, args)
st.write(f'Predicted structures saved to {project_name}')