Spaces:
Configuration error
Configuration error
from sklearn.neighbors import KNeighborsClassifier | |
import streamlit as st | |
from rdkit.Chem import MACCSkeys | |
from rdkit import Chem | |
import numpy as np | |
import pandas as pd | |
import xgboost as xgb | |
from sklearn.svm import SVC | |
import matplotlib.pyplot as plt | |
from sklearn.model_selection import train_test_split, cross_val_score | |
from sklearn.metrics import classification_report, confusion_matrix, average_precision_score, roc_auc_score | |
import pickle | |
# Directory prefix for serialized models. Nothing in the visible part of this
# script reads it; it is kept so any external reference keeps working.
# (The former module-level `global header` statement was removed: `global` has
# no effect at module scope and `header` was never defined or used.)
model_path = 'model/'
def load_tpr_fpr(ml, enzyme):
    """Load the pickled TPR and FPR objects stored for one model/enzyme pair.

    The two files are looked up relative to the current working directory as
    ``AUC/<ml>_<enzyme>_tpr.pickle`` and ``AUC/<ml>_<enzyme>_fpr.pickle``.

    Args:
        ml: Model identifier used in the file name (a string).
        enzyme: Enzyme identifier used in the file name (a string).

    Returns:
        A ``(tpr, fpr)`` tuple holding whatever objects the two pickle files
        contain (presumably ROC curve coordinates -- confirm against the code
        that produced the files).

    Raises:
        FileNotFoundError: If either pickle file does not exist.

    Note:
        ``pickle.load`` can execute arbitrary code; only trusted files should
        live in the ``AUC/`` directory.
    """
    def _read_curve(kind):
        # str.join keeps the original's strictness: non-string parts raise
        # TypeError instead of being silently formatted.
        path = ''.join(('AUC/', ml, '_', enzyme, '_', kind, '.pickle'))
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    # TPR is read first, then FPR, matching the order of the returned tuple.
    true_positive = _read_curve('tpr')
    false_positive = _read_curve('fpr')
    return true_positive, false_positive
def smile_list_to_MACCS(smi_list):
    """Convert an iterable of SMILES strings into MACCS-key fingerprints.

    Each SMILES string is parsed with ``Chem.MolFromSmiles``, its MACCS keys
    are generated with ``MACCSkeys.GenMACCSKeys``, and the fingerprint's bit
    string is split into a list of single-character strings.

    Args:
        smi_list: Iterable of SMILES strings.

    Returns:
        A list with one entry per input SMILES, in input order; each entry is
        the list of characters of that molecule's MACCS bit string.

    Note:
        No validation is done on the parsed molecule, so whatever
        ``GenMACCSKeys`` raises for an unparsable SMILES propagates.
    """
    return [
        list(MACCSkeys.GenMACCSKeys(Chem.MolFromSmiles(smiles)).ToBitString())
        for smiles in smi_list
    ]
# Page title, written as a Markdown level-1 heading (typo "Ploting" fixed),
# followed by the heading shown at the top of the sidebar.
st.write("""
# Area Under the Curve Plotting
""")
st.sidebar.header('User Input Parameters')
def _checkbox_group(prompt, options, select_all_label):
    """Render one titled group of checkboxes and return the selected keys.

    Args:
        prompt: Text written above the checkboxes.
        options: Sequence of ``(checkbox_label, key)`` pairs, in display order.
        select_all_label: Label of the trailing "select all" checkbox.

    Returns:
        A new list of the keys whose checkbox is ticked, in ``options`` order;
        every key when the "select all" checkbox is ticked.
    """
    st.write(prompt)
    # Instantiate every checkbox unconditionally so the widgets are always
    # rendered in the same order, then filter on their states afterwards.
    states = [(key, st.checkbox(label)) for label, key in options]
    if st.checkbox(select_all_label):
        return [key for _, key in options]
    return [key for key, checked in states if checked]


def user_input_features():
    """Render the kinase and model selection widgets and return the choices.

    The kinase group is rendered first, then the model group; each group ends
    with a "select all" checkbox that overrides the individual boxes.

    Returns:
        An ``(enzyme, model)`` tuple of two lists of strings: the selected JAK
        kinase names and the selected model identifiers. Either list is empty
        when nothing in its group is ticked.
    """
    # (checkbox label, identifier returned to the caller)
    enzyme_options = [
        ('JAK1', 'JAK1'),
        ('JAK2', 'JAK2'),
        ('JAK3', 'JAK3'),
        ('TYK2', 'TYK2'),
    ]
    # Labels and identifiers differ for some models (e.g. 'KNN' -> 'knn',
    # 'GraphVAE' -> 'GVAE'); the identifiers are what callers receive.
    model_options = [
        ('KNN', 'knn'),
        ('SVM_linear', 'SVM_linear'),
        ('SVM_poly', 'SVM_poly'),
        ('SVM_rbf', 'SVM_rbf'),
        ('SVM_sigmoid', 'SVM_sigmoid'),
        ('RF', 'RF'),
        ('XGBoost', 'XGBoost'),
        ('CNN', 'CNN'),
        ('GraphVAE', 'GVAE'),
        ('chembert', 'chembert'),
    ]

    enzyme = _checkbox_group(
        'Select JAK kinase: ', enzyme_options, 'Select all JAKs')
    model = _checkbox_group(
        'Select model: ', model_options, 'Select all models')
    return enzyme, model
# Collect the user's choices from widgets rendered inside the sidebar.
with st.sidebar:
    enzymes, model_chosen = user_input_features()

# Echo the current selection in the main page area.
st.subheader('User Input parameters:')
st.write('Selected JAK:', enzymes)
st.write('Selected model: ', model_chosen)

if st.button('Start Plot AUC'):
    # Report every missing selection; both messages appear when both are empty.
    if not model_chosen:
        st.write('Did not choose model!')
    if not enzymes:
        st.write('Did not choose JAK kinase!')

    if model_chosen and enzymes:
        # Diagonal reference line, identical for every figure, so build it once.
        diagonal = np.linspace(0, 1, 100)

        # One figure per selected kinase, one curve per selected model.
        for enzyme in enzymes:
            fig, ax = plt.subplots(figsize=(10, 10))
            for ml in model_chosen:
                tpr, fpr = load_tpr_fpr(ml, enzyme)
                ax.plot(fpr, tpr, label=ml)
            ax.plot(diagonal, diagonal, label='baseline', linestyle='--')

            # Style through the Axes object rather than pyplot's implicit
            # "current figure" state.
            ax.set_title(enzyme + ' Receiver Operating Characteristic Curve',
                         fontsize=18)
            ax.set_ylabel('TPR', fontsize=16)
            ax.set_xlabel('FPR', fontsize=16)
            ax.legend(fontsize=12)

            st.pyplot(fig)
            # pyplot keeps every figure it creates registered until closed;
            # close it once rendered so reruns do not accumulate figures.
            plt.close(fig)