import itertools as it
import os
from typing import Optional

import joblib
import numpy as np
import pandas as pd
import pkg_resources
import streamlit as st

from b3clf.descriptor_padel import compute_descriptors
from b3clf.geometry_opt import geometry_optimize
from b3clf.utils import get_descriptors, scale_descriptors, select_descriptors


@st.cache_resource
def load_all_models():
    """Load every pre-trained b3clf classifier shipped with the package.

    Returns
    -------
    dict
        Maps ``"<classifier>_<sampling>"`` (for example
        ``"xgb_classic_ADASYN"``) to the object loaded from the packaged
        ``pre_trained/b3clf_<classifier>_<sampling>.joblib`` resource.
    """
    clf_list = ["dtree", "knn", "logreg", "xgb"]
    sampling_list = [
        "borderline_SMOTE",
        "classic_ADASYN",
        "classic_RandUndersampling",
        "classic_SMOTE",
        "kmeans_SMOTE",
        "common",
    ]
    package_name = "b3clf"

    model_dict = {}
    for clf_str, sampling_str in it.product(clf_list, sampling_list):
        joblib_path_str = f"pre_trained/b3clf_{clf_str}_{sampling_str}.joblib"
        # Read the model through the package resource API so it is found
        # regardless of where the package is installed.
        with pkg_resources.resource_stream(package_name, joblib_path_str) as f:
            model_dict[clf_str + "_" + sampling_str] = joblib.load(f)

    return model_dict


@st.cache_resource
def predict_permeability(
    clf_str, sampling_str, _models_dict, mol_features, info_df, threshold="none"
):
    """Compute permeability predictions for the given feature data.

    Parameters
    ----------
    clf_str : str
        Classifier name; combined with ``sampling_str`` as
        ``"<clf_str>_<sampling_str>"`` to look up the model.
    sampling_str : str
        Sampling strategy name (second half of the model key).
    _models_dict : dict
        Mapping of model keys to fitted classifiers exposing
        ``predict_proba``. The leading underscore follows the Streamlit
        convention for arguments excluded from cache hashing.
    mol_features : pandas.DataFrame or array-like
        Feature matrix passed to ``predict_proba``.
    info_df : pandas.DataFrame
        Per-molecule information. It is modified in place: the columns
        ``B3clf_predicted_probability`` and ``B3clf_predicted_label`` are
        added and its index is reset into a column.
    threshold : str, optional
        Column of the packaged ``data/B3clf_thresholds.xlsx`` table that
        holds the probability cut-off.

    Returns
    -------
    pandas.DataFrame
        The same ``info_df`` object, with the prediction columns added.

    Raises
    ------
    ValueError
        If ``mol_features`` is a DataFrame whose index differs from that of
        ``info_df``.
    """
    pred_model = _models_dict[clf_str + "_" + sampling_str]

    # Load the table of probability thresholds bundled with the package.
    package_name = "b3clf"
    with pkg_resources.resource_stream(package_name, "data/B3clf_thresholds.xlsx") as f:
        df_thres = pd.read_excel(f, index_col=0, engine="openpyxl")

    if isinstance(mol_features, pd.DataFrame):
        if mol_features.index.tolist() != info_df.index.tolist():
            raise ValueError("Features_df and Info_df do not have the same index.")

    # Probability of the positive class (second column of predict_proba).
    info_df.loc[:, "B3clf_predicted_probability"] = pred_model.predict_proba(
        mol_features
    )[:, 1]

    # NOTE(review): the threshold row is hard-coded to "xgb-classic_ADASYN"
    # instead of f"{clf_str}-{sampling_str}"; this is kept as-is to preserve
    # the existing predictions -- confirm whether per-model thresholds are
    # intended.
    mask = np.greater_equal(
        info_df["B3clf_predicted_probability"].to_numpy(),
        df_thres.loc["xgb-classic_ADASYN", threshold],
    )

    # Label is 1 where the probability reaches the threshold, 0 otherwise.
    label_pool = np.zeros(mol_features.shape[0], dtype=int)
    label_pool[mask] = 1
    info_df["B3clf_predicted_label"] = label_pool

    info_df.reset_index(inplace=True)

    return info_df


@st.cache_resource
def generate_predictions(
    input_fname: Optional[str] = None,
    # Same runtime value as the former "\s+|\t+" literal, written without the
    # invalid "\s" escape that newer Python versions warn about.
    sep: str = "\\s+|\t+",
    clf: str = "xgb",
    _models_dict: Optional[dict] = None,
    keep_sdf: str = "no",
    sampling: str = "classic_ADASYN",
    time_per_mol: int = 120,
    mol_features: Optional[pd.DataFrame] = None,
    info_df: Optional[pd.DataFrame] = None,
):
    """Generate permeability predictions for a given input file.

    When both ``mol_features`` and ``info_df`` are ``None`` the descriptors
    are computed from ``input_fname``; otherwise the supplied frames are
    passed to the classifier as they are.

    Parameters
    ----------
    input_fname : str, optional
        Input file: either an SDF file with molecular geometries or a text
        file with SMILES strings. Used only when descriptors are computed.
    sep : str, optional
        Separator forwarded to ``geometry_optimize``.
    clf : str, optional
        Classifier name used to select the model.
    _models_dict : dict, optional
        Pre-loaded models keyed by ``"<clf>_<sampling>"``.
    keep_sdf : str, optional
        ``"no"`` deletes the intermediate optimized SDF file; any other
        value keeps it.
    sampling : str, optional
        Sampling strategy name used to select the model.
    time_per_mol : int, optional
        Forwarded to ``compute_descriptors`` as ``time_per_molecule``.
    mol_features : pandas.DataFrame, optional
        Pre-computed feature matrix.
    info_df : pandas.DataFrame, optional
        Pre-computed molecule information matching ``mol_features``.

    Returns
    -------
    tuple
        ``(mol_features, info_df, result_df)`` where ``result_df`` holds the
        display columns (``ID``, ``SMILES``, predicted probability and
        label) that are present in the prediction output.
    """
    if mol_features is None and info_df is None:
        mol_tag = os.path.basename(input_fname).split(".")[0]
        # NOTE: the path depends only on the input's base name, so runs on
        # inputs sharing a base name also share this intermediate file.
        internal_sdf = f"{mol_tag}_optimized_3d.sdf"

        try:
            # Geometry optimization of the input structures / SMILES.
            geometry_optimize(
                input_fname=input_fname, output_sdf=internal_sdf, sep=sep
            )

            df_features = compute_descriptors(
                sdf_file=internal_sdf,
                excel_out=None,
                output_csv=None,
                timeout=None,
                time_per_molecule=time_per_mol,
            )

            # Split the computed descriptors, then select and scale them.
            mol_features, info_df = get_descriptors(df=df_features)
            mol_features = select_descriptors(df=mol_features)
            mol_features.iloc[:, :] = scale_descriptors(df=mol_features)
        finally:
            # Clean up the intermediate file even when a step above fails,
            # so a failed run does not leave it behind.
            if keep_sdf == "no" and os.path.exists(internal_sdf):
                os.remove(internal_sdf)

    result_df = predict_permeability(
        clf_str=clf,
        sampling_str=sampling,
        _models_dict=_models_dict,
        mol_features=mol_features,
        info_df=info_df,
        threshold="none",
    )

    # Keep only the columns meant for display, in their existing order.
    display_cols = [
        "ID",
        "SMILES",
        "B3clf_predicted_probability",
        "B3clf_predicted_label",
    ]
    result_df = result_df[[col for col in result_df.columns if col in display_cols]]

    return mol_features, info_df, result_df