Spaces:
Sleeping
Sleeping
File size: 7,516 Bytes
fba5f30 ce48cf1 fba5f30 9184a71 9000099 ce48cf1 fba5f30 d914a39 fba5f30 a9d689d fba5f30 8b24986 fba5f30 ce48cf1 fba5f30 a9d689d 5fd4554 a9d689d fba5f30 0942112 fba5f30 ce48cf1 6fe2177 ce48cf1 066ac37 d914a39 9184a71 066ac37 d8ff1f1 9184a71 066ac37 5fd4554 d8ff1f1 066ac37 739f128 5f7e2e9 066ac37 6fe2177 95f7f6a 066ac37 9184a71 066ac37 b26bcc3 066ac37 e5535aa 066ac37 e5535aa 066ac37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import argparse
import os
import re
import tempfile

import h5py
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import xgboost as xgb
from sklearn.preprocessing import OrdinalEncoder
def get_POMFinder(
    DataBase_path="Backend/POMFinder_443structures_100Dataset_per_Structure_xPDF_hypercube_sampling_Grmax_Name.h5",
    POMFinder_path="Backend/XGBoost_443structures_100PDFperStructure_xPDF_hypercube_sampling_Grmax.model",
):
    """Load the structure-label encoder and the trained POMFinder classifier.

    Parameters
    ----------
    DataBase_path : str
        HDF5 file containing a ``'y'`` dataset with the structure labels.
    POMFinder_path : str
        Saved XGBoost model file loaded into an ``xgb.XGBClassifier``.

    Returns
    -------
    y : numpy.ndarray
        Contents of the ``'y'`` dataset, read into memory.
    y_onehotenc_cat : OrdinalEncoder
        Encoder fitted on ``y``; ``categories_[0]`` maps class indices to labels.
    y_onehotenc_values : numpy.ndarray
        ``y`` transformed to ordinal class indices.
    POMFinder : xgb.XGBClassifier
        The loaded classifier.
    """
    # Read the labels into memory so the HDF5 file is closed before returning.
    with h5py.File(DataBase_path, "r") as hf_name:
        y = np.array(hf_name["y"])
    # Fit the encoder once. fit_transform fits `enc` in place, so the same
    # object is the fitted encoder handed back to callers.
    enc = OrdinalEncoder()
    y_onehotenc_values = enc.fit_transform(y)
    y_onehotenc_cat = enc
    # Import POMFinder
    POMFinder = xgb.XGBClassifier()
    POMFinder.load_model(POMFinder_path)
    return y, y_onehotenc_cat, y_onehotenc_values, POMFinder
def _read_pdf_table(path):
    """Return the numeric table of a PDF text file as a 2-D array, skipping any header.

    The data block starts at the first line that has at least two
    whitespace-separated fields which all parse as floats. The file on disk
    is only read, never modified.

    Raises
    ------
    ValueError
        If no such line exists.
    """
    # errors="replace": header lines may contain non-ASCII text (e.g. units);
    # they are skipped anyway, so undecodable bytes must not abort the read.
    with open(path, "r", encoding="utf-8", errors="replace") as handle:
        lines = handle.read().splitlines()
    for start, line in enumerate(lines):
        fields = line.split()
        if len(fields) < 2:
            continue
        try:
            for field in fields:
                float(field)
        except ValueError:
            continue
        return np.loadtxt(lines[start:], ndmin=2)
    raise ValueError(f"No numeric (r, G(r)) data found in {path!r}.")


def _resample_gr(r, Gr, rmax, nyquist):
    """Return a new array of round(rmax * 10) + 1 G(r) values for the model.

    The first 10 output points are set to zero. ``Gr`` itself is not modified.
    """
    n_points = int(round(rmax * 10)) + 1
    # Work on a copy: the slices below are NumPy views, and the in-place
    # edits further down would otherwise overwrite the caller's raw data.
    Gr = np.array(Gr, dtype=float)
    if r[0] != 0:  # In the case that the data does not start at 0.
        start = np.flatnonzero(np.isclose(r, 1.0))
        if start.size == 0:
            raise ValueError(
                "Data that does not start at r = 0 must contain a point at r = 1."
            )
        Gr = Gr[start[0]:]  # Drop all points below r = 1.
        Gr = Gr[::10]  # Keep every 10th point (presumably a 0.01 A input grid - TODO confirm).
        Gr = np.concatenate((np.zeros(10), Gr))  # Prepend 10 zero points for the dropped range.
    if not nyquist:
        # NOTE(review): when r[0] != 0 the data was already subsampled above,
        # so this second [::10] applies on top of it; kept as-is, confirm intent.
        Gr = Gr[::10]  # Pseudo Nyquist sample data.
    if len(Gr) >= n_points:
        Gr = Gr[:n_points]  # Discard data beyond rmax.
    else:
        Gr = np.concatenate((Gr, np.zeros(n_points - len(Gr))))  # Zero-pad up to rmax.
    Gr[:10] = 0.0
    return Gr


def _plot_preparation(PDF, r, Gr):
    """Show the raw data next to the model-ready G(r) on the Streamlit page."""
    fig, ax = plt.subplots()
    ax.plot(PDF[:, 0], PDF[:, 1], label="Original Data")
    ax.plot(r, Gr[0, 3:], label="Gr ready for ML")
    ax.legend()
    ax.set_title("Original Data vs. normalised Data")
    ax.set_xlabel("r (AA)")
    ax.set_ylabel("Gr")
    st.pyplot(fig)
    # Release the figure; pyplot otherwise keeps every figure from every rerun alive.
    plt.close(fig)


def PDF_Preparation(Your_PDF_Name, Qmin, Qmax, Qdamp, rmax, nyquist, plot=True):
    """Load an experimental PDF and turn it into a single POMFinder input row.

    Parameters
    ----------
    Your_PDF_Name : str
        Path to a text file whose numeric block has r in the first column and
        G(r) in the second; non-numeric header lines are skipped. The file is
        not modified.
    Qmin, Qmax, Qdamp : float
        Experimental parameters prepended to the feature vector.
    rmax : float
        Largest r kept; the output has ``round(rmax * 10) + 1`` G(r) points.
    nyquist : bool
        If False, every 10th point is kept before truncation/padding.
    plot : bool, optional
        If True (default), render a comparison plot with Streamlit.

    Returns
    -------
    r : numpy.ndarray
        ``np.linspace(0, rmax, n_points)``, the grid the G(r) values are plotted on.
    Gr : numpy.ndarray
        Shape ``(1, 3 + n_points)``: ``[Qmin, Qmax, Qdamp, G(r)...]`` with G(r)
        scaled so its maximum is 1.

    Raises
    ------
    ValueError
        If the file has no numeric data block, if data not starting at r = 0
        lacks a point at r = 1, or if G(r) is zero at every retained point.
    """
    PDF = _read_pdf_table(Your_PDF_Name)
    Gr = _resample_gr(PDF[:, 0], PDF[:, 1], rmax, nyquist)
    r = np.linspace(0, rmax, len(Gr))
    # Normalise it to the data from the database
    peak = np.max(Gr)
    if peak == 0:
        raise ValueError("G(r) is zero at every retained point; cannot normalise.")
    Gr = Gr / peak
    # Prepend the experimental parameters and add a leading batch axis.
    Gr = np.concatenate(([Qmin, Qmax, Qdamp], Gr))[np.newaxis, :]
    if plot:
        _plot_preparation(PDF, r, Gr)
    return r, Gr
def _ordinal(n):
    """Return the English ordinal for a positive integer, e.g. 1 -> '1st', 12 -> '12th'."""
    if 10 <= n % 100 <= 20:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{suffix}"


def POMPredicter(POMFinder, Gr, y_onehotenc_cat, top_k=5):
    """Rank candidate structures for one prepared PDF and display the best guesses.

    Parameters
    ----------
    POMFinder : xgb.XGBClassifier
        Classifier returned by ``get_POMFinder``.
    Gr : numpy.ndarray
        Model input as produced by ``PDF_Preparation``.
    y_onehotenc_cat : OrdinalEncoder
        Fitted encoder whose ``categories_[0]`` maps class indices to labels.
    top_k : int, optional
        Number of guesses to display; fewer are shown if fewer classes exist.

    Returns
    -------
    res : list of int
        Indices ordered from most to least probable.
    y_pred_proba : numpy.ndarray
        The probabilities used for the ranking.
    """
    y_pred_proba = POMFinder.predict_proba(Gr)
    # NOTE(review): column 1 of predict_proba's output is used as the ranking
    # vector, unchanged from before; confirm against the saved model's output layout.
    y_pred_proba = y_pred_proba[:, 1]
    # Stable ascending sort, then reverse: ties come out in descending index order.
    res = sorted(range(len(y_pred_proba)), key=lambda sub: y_pred_proba[sub])
    res.reverse()
    for rank, class_index in enumerate(res[:top_k], start=1):
        structure_name = str(y_onehotenc_cat.categories_[0][class_index])[2:-2] + "cale.xyz"
        # predict_proba returns fractions in [0, 1]; scale so the "%" label is correct.
        probability = y_pred_proba[class_index] * 100
        st.markdown(
            f'<span style="font-size: 24px; color: green;">The {_ordinal(rank)} guess from the model is: '
            f'<b>{structure_name}</b> with a probability of {probability:.2f} %</span> <hr/>',
            unsafe_allow_html=True,
        )
    return res, y_pred_proba
# Define a download button to download the file
def download_button(file_name, button_text):
    """Render a Streamlit download button serving the file at ``file_name``.

    Parameters
    ----------
    file_name : str
        Path of the file on the server; its contents are read fully into memory.
    button_text : str
        Label shown on the button.
    """
    with open(file_name, "rb") as f:
        payload = f.read()
    st.download_button(
        label=button_text,
        data=payload,
        # Offer only the base name so the server's directory layout does not
        # end up in the name of the downloaded file.
        file_name=os.path.basename(file_name),
        mime="text/xyz",
    )
st.title('POMFinder')
st.write('Welcome to POMFinder which is a tree-based supervised learning algorithm that can predict the polyoxometalate cluster from a Pair Distribution Function.')
st.write('Upload a PDF to use POMFinder to predict the structure.')
# Define the file upload widget
pdf_file = st.file_uploader("Upload PDF file in .gr format", type=["gr"])
# Experimental parameters passed to the model. They are fixed here; the
# commented-out widgets show how they could be exposed to the user instead.
Qmin = 0.7 #st.number_input("Qmin value of the experimental PDF", min_value=0.0, max_value=2.0, value=0.7)
Qmax = 30 #st.number_input("Qmax value of the experimental PDF", min_value=15.0, max_value=40.0, value=30.0)
Qdamp = 0.04 #st.number_input("Qdamp value of the experimental PDF", min_value=0.00, max_value=0.08, value=0.04)
nyquist = st.checkbox("Is the data nyquist sampled", value=False)
# Run settings, built directly as a Namespace. No options are declared, so
# ArgumentParser().parse_args() only risked aborting the app (SystemExit)
# if any command-line argument reached the script.
args = argparse.Namespace(
    data="uploaded_file.gr",
    nyquist=nyquist,
    Qmin=Qmin,
    Qmax=Qmax,
    Qdamp=Qdamp,
    file_name="POMFinder_results.txt",
)
if pdf_file is None:
    st.warning("Please upload a PDF file.")
else:
    # Save the upload in a private temporary directory: a fixed path in the
    # working directory would be the same file for every concurrent session.
    with tempfile.TemporaryDirectory() as upload_dir:
        args.data = os.path.join(upload_dir, "uploaded_file.gr")
        with open(args.data, "wb") as f:
            # getvalue() returns the whole upload regardless of stream position.
            f.write(pdf_file.getvalue())
        # Predict with POMFinder
        y, y_onehotenc_cat, y_onehotenc_values, POMFinder = get_POMFinder()
        r, Gr = PDF_Preparation(args.data, args.Qmin, args.Qmax, args.Qdamp, rmax=10, nyquist=args.nyquist)
        res, y_pred_proba = POMPredicter(POMFinder, Gr, y_onehotenc_cat)
    # Download the structural database
    #download_button("COD_ICSD_XYZs_POMs_unique99.zip", "Download structural database")
    # Only the top-ranked structure file is offered, so the label says top-1.
    download_button("Backend/COD_ICSD_XYZs_POMs_unique99/"+str(y_onehotenc_cat.categories_[0][res[0]])[2:-2]+"cale.xyz", "Download top-1 prediction")
st.subheader('Cite')
st.write('If you use POMFinder, our code or results, please consider citing our paper. Thanks in advance!')
st.write('POMFinder: Identifying polyoxometalate cluster structures from pair distribution function data using explainable machine learning **2023** (https://chemrxiv.org/engage/chemrxiv/article-details/64e5fef7dd1a73847f5951b9)')
st.subheader('LICENSE')
st.write('This project is licensed under the Apache License Version 2.0, January 2004 - see the LICENSE file at https://github.com/AndySAnker/POMFinder/blob/master/LICENSE.txt for details.')
st.write("")
st.subheader('Github')
st.write('https://github.com/AndySAnker/POMFinder')
st.subheader('Questions')
st.write('andy@chem.ku.dk or etsk@chem.ku.dk')
|