# b3clf / utils.py
# Last change by legend1234: "Fix data shape mismatch" (commit 46411f2)
import itertools as it
import os
import joblib
import numpy as np
import pandas as pd
import pkg_resources
import streamlit as st
from b3clf.descriptor_padel import compute_descriptors
from b3clf.geometry_opt import geometry_optimize
from b3clf.utils import get_descriptors, scale_descriptors, select_descriptors
@st.cache_resource()
def load_all_models():
    """Load every pre-trained B3clf classifier shipped inside the ``b3clf`` package.

    One model is loaded per (classifier, resampling strategy) pair from the
    package resource ``pre_trained/b3clf_<clf>_<sampling>.joblib``.

    Returns
    -------
    dict
        Mapping ``"<clf>_<sampling>"`` -> the object returned by ``joblib.load``.
        Keys are inserted classifier-major (all samplings for "dtree" first, ...).
    """
    classifiers = ("dtree", "knn", "logreg", "xgb")
    samplers = (
        "borderline_SMOTE",
        "classic_ADASYN",
        "classic_RandUndersampling",
        "classic_SMOTE",
        "kmeans_SMOTE",
        "common",
    )

    models = {}
    for clf_name in classifiers:
        for sampler_name in samplers:
            resource = f"pre_trained/b3clf_{clf_name}_{sampler_name}.joblib"
            # joblib.load unpickles; acceptable only because these files ship
            # with the installed package rather than coming from user input.
            with pkg_resources.resource_stream("b3clf", resource) as stream:
                models[f"{clf_name}_{sampler_name}"] = joblib.load(stream)
    return models
@st.cache_resource
def predict_permeability(
    clf_str, sampling_str, _models_dict, mol_features, info_df, threshold="none"
):
    """Compute permeability predictions for the given feature data with one model.

    Parameters
    ----------
    clf_str, sampling_str : str
        Classifier and resampling names. ``clf_str + "_" + sampling_str`` selects
        the model in ``_models_dict``; ``clf_str + "-" + sampling_str`` selects the
        row of the bundled threshold sheet.
    _models_dict : dict
        Mapping of model key -> fitted classifier exposing ``predict_proba``.
        The leading underscore makes Streamlit skip hashing this argument.
    mol_features : pandas.DataFrame or array-like
        Descriptor matrix, one row per molecule. When it is a DataFrame its
        index must equal ``info_df``'s index.
    info_df : pandas.DataFrame
        Per-molecule metadata. NOTE: modified in place -- prediction columns
        are added and its index is reset.
    threshold : str
        Column of ``data/B3clf_thresholds.xlsx`` holding the decision cutoff.

    Returns
    -------
    pandas.DataFrame
        ``info_df`` with ``B3clf_predicted_probability`` and
        ``B3clf_predicted_label`` columns and a reset index.

    Raises
    ------
    ValueError
        If ``mol_features`` is a DataFrame whose index differs from ``info_df``'s.
    KeyError
        If the model key is missing from ``_models_dict`` or ``threshold`` is
        not a column of the threshold sheet.
    """
    pred_model = _models_dict[clf_str + "_" + sampling_str]

    # Decision thresholds ship with the b3clf package, indexed by "<clf>-<sampling>".
    package_name = "b3clf"
    with pkg_resources.resource_stream(package_name, "data/B3clf_thresholds.xlsx") as f:
        df_thres = pd.read_excel(f, index_col=0, engine="openpyxl")

    # Use the cutoff tuned for the selected model. The "xgb-classic_ADASYN" row
    # used to be applied to every model; keep it as the fallback when the sheet
    # has no row for this combination.
    thres_key = f"{clf_str}-{sampling_str}"
    if thres_key not in df_thres.index:
        thres_key = "xgb-classic_ADASYN"
    cutoff = df_thres.loc[thres_key, threshold]

    # Every molecule starts with label 0; those at or above the cutoff become 1.
    label_pool = np.zeros(mol_features.shape[0], dtype=int)

    if isinstance(mol_features, pd.DataFrame):
        if mol_features.index.tolist() != info_df.index.tolist():
            raise ValueError("mol_features and Info_df do not have the same index.")

    # Second predict_proba column -- presumably the permeable (BBB+) class; confirm
    # against the model's classes_.
    info_df.loc[:, "B3clf_predicted_probability"] = pred_model.predict_proba(
        mol_features
    )[:, 1]

    mask = np.greater_equal(
        info_df["B3clf_predicted_probability"].to_numpy(),
        cutoff,
    )
    label_pool[mask] = 1

    info_df["B3clf_predicted_label"] = label_pool
    info_df.reset_index(inplace=True)
    return info_df
def _featurize_input_file(input_fname, sep, keep_sdf, time_per_mol):
    """Turn a molecule file into (scaled descriptors, metadata) ready for prediction.

    Runs geometry optimization into a temporary SDF, computes descriptors,
    selects the model descriptors, drops molecules with non-numeric/missing
    descriptor values, and scales the rest. The temporary SDF is removed
    afterwards unless ``keep_sdf`` is not ``"no"``.

    Returns
    -------
    tuple
        ``(mol_features, info_df)`` sharing the same index.

    Raises
    ------
    ValueError
        If the file extension is unsupported or no valid molecules remain.
    """
    mol_tag = os.path.basename(input_fname).split(".")[0]
    file_ext = os.path.splitext(input_fname)[1].lower()
    internal_sdf = f"{mol_tag}_optimized_3d.sdf"

    try:
        # The separator depends on the file type; .sdf keeps the caller's `sep`.
        if file_ext == ".csv":
            sep = ","
        elif file_ext in (".txt", ".smi"):
            sep = r"\s+|\t+"
        elif file_ext != ".sdf":
            raise ValueError(f"Unsupported file type: {file_ext}")

        geometry_optimize(input_fname=input_fname, output_sdf=internal_sdf, sep=sep)

        df_features = compute_descriptors(
            sdf_file=internal_sdf,
            excel_out=None,
            output_csv=None,
            timeout=time_per_mol * 2,  # total timeout is twice the per-molecule budget
            time_per_molecule=time_per_mol,
        )

        mol_features, info_df = get_descriptors(df=df_features)
        mol_features = select_descriptors(df=mol_features)

        # Empty strings / non-numeric entries become NaN; such molecules are dropped
        # from both frames so their indices stay aligned.
        mol_features = mol_features.replace("", np.nan)
        mol_features = mol_features.apply(pd.to_numeric, errors="coerce")
        if mol_features.isnull().any().any():
            st.warning("Some descriptors contained invalid values and were removed")
            valid_indices = ~mol_features.isnull().any(axis=1)
            # .copy() so the in-place writes below (and in predict_permeability)
            # act on independent frames rather than flagged slices.
            mol_features = mol_features[valid_indices].copy()
            info_df = info_df[valid_indices].copy()

        # Checked unconditionally: an empty descriptor table is unusable whether
        # or not any rows were dropped above.
        if len(mol_features) == 0:
            raise ValueError("No valid data remains after cleaning")

        mol_features.iloc[:, :] = scale_descriptors(df=mol_features)
    finally:
        # Best-effort cleanup of the intermediate SDF; a failed delete must not
        # mask the real result or error.
        if keep_sdf == "no" and os.path.exists(internal_sdf):
            try:
                os.remove(internal_sdf)
            except OSError:
                pass

    return mol_features, info_df


@st.cache_resource
def generate_predictions(
    input_fname: str = None,
    sep: str = r"\s+|\t+",
    clf: str = "xgb",
    _models_dict: dict = None,
    keep_sdf: str = "no",
    sampling: str = "classic_ADASYN",
    time_per_mol: int = 120,
    mol_features: pd.DataFrame = None,
    info_df: pd.DataFrame = None,
):
    """Generate B3clf permeability predictions for a file or precomputed features.

    Either pass ``input_fname`` (.sdf, .csv, .txt or .smi) so descriptors are
    computed from scratch, or pass both ``mol_features`` and ``info_df`` to
    reuse features from an earlier call.

    Parameters
    ----------
    input_fname : str, optional
        Molecule file; only used when ``mol_features``/``info_df`` are not given.
    sep : str
        Separator regex forwarded to geometry optimization for .sdf input;
        replaced by "," for .csv and by the whitespace regex for .txt/.smi.
    clf, sampling : str
        Classifier and resampling names selecting the pre-trained model.
    _models_dict : dict
        Preloaded models keyed ``"<clf>_<sampling>"``; not hashed by Streamlit.
    keep_sdf : str
        ``"no"`` deletes the intermediate optimized SDF; any other value keeps it.
    time_per_mol : int
        Per-molecule descriptor time budget; the total timeout is twice this.
    mol_features, info_df : pandas.DataFrame, optional
        Precomputed scaled descriptors and matching metadata; give both or neither.

    Returns
    -------
    tuple
        ``(mol_features, info_df, result_df)``; ``result_df`` keeps only the
        ID, SMILES, probability and label columns that are present.

    Raises
    ------
    ValueError
        If no input is given, only one of ``mol_features``/``info_df`` is
        given, the file type is unsupported, or no valid molecules remain.
    """
    try:
        if mol_features is None and info_df is None:
            if input_fname is None:
                raise ValueError("Either input_fname or mol_features/info_df must be provided")
            mol_features, info_df = _featurize_input_file(
                input_fname, sep, keep_sdf, time_per_mol
            )
        elif mol_features is None or info_df is None:
            # Fail clearly instead of crashing later on a None argument.
            raise ValueError("mol_features and info_df must be provided together")

        result_df = predict_permeability(
            clf_str=clf,
            sampling_str=sampling,
            _models_dict=_models_dict,
            mol_features=mol_features,
            info_df=info_df,
            threshold="none",
        )

        display_cols = [
            "ID",
            "SMILES",
            "B3clf_predicted_probability",
            "B3clf_predicted_label",
        ]
        result_df = result_df[
            [col for col in result_df.columns.to_list() if col in display_cols]
        ]
        return mol_features, info_df, result_df
    except Exception as e:
        # UI boundary: show the full traceback in the app, then re-raise.
        import traceback

        st.error(f"Error in generate_predictions: {str(e)}\n{traceback.format_exc()}")
        raise