# mbp / app.py
# jiaxianustc's picture
# Initial commit
# 865e35e
import gradio as gr
import argparse
import os
from UltraFlow.models.sbap import *
from UltraFlow import commons
import warnings
# Suppress every warning raised in this process.
warnings.filterwarnings("ignore")
# Directory passed to get_config() and get_models() at start-up; it holds
# 'affinity_default.yaml' and the checkpoint file named below.
model_dir = './workdir/gradio/'
# Checkpoint file name, joined with model_dir inside get_models().
checkpoint = 'checkpointbest_valid_1.ckp'
# Running count of calls to test(); incremented and printed on every call.
total_num = 0
def get_config(model_dir):
    """Load the demo configuration stored in *model_dir*.

    Reads ``affinity_default.yaml`` from *model_dir* through
    ``commons.get_config_easydict``, forces the device to ``'cpu'`` and
    passes the configured ``seed`` to ``commons.set_seed``.

    Args:
        model_dir: Directory containing ``affinity_default.yaml``.

    Returns:
        The loaded configuration object with ``device`` set to ``'cpu'``.
    """
    yaml_path = os.path.join(model_dir, 'affinity_default.yaml')
    cfg = commons.get_config_easydict(yaml_path)

    # The demo always runs on CPU; no GPU selection is attempted.
    cfg.device = 'cpu'

    commons.set_seed(cfg.seed)
    return cfg
def load_graph_dim(lig_graph, prot_graph, model_config):
    """Compute the feature widths of the ligand and protein graphs.

    Each base width is ``shape[1]`` of the corresponding feature entry
    (``ndata['h']`` for nodes, ``edata['e']`` for edges).

    Args:
        lig_graph: Ligand graph exposing ``ndata`` and ``edata`` mappings.
        prot_graph: Protein graph exposing ``ndata`` and ``edata`` mappings.
        model_config: Configuration whose ``data`` section provides the
            ``add_chemical_bond_feats`` and ``use_mean_node_features`` flags.

    Returns:
        Tuple ``(lig_node_dim, lig_edge_dim, pro_node_dim, pro_edge_dim,
        inter_edge_dim)``; ``inter_edge_dim`` is always 15.
    """
    flags = model_config.data

    # Width of the bond-type columns, counted only when the flag is set.
    if flags.add_chemical_bond_feats:
        bond_cols = lig_graph.edata['bond_type'].shape[1]
    else:
        bond_cols = 0

    # With mean node features enabled, both node widths grow by 5.
    mean_cols = 5 if flags.use_mean_node_features else 0

    inter_edge_dim = 15
    return (
        lig_graph.ndata['h'].shape[1] + mean_cols,
        lig_graph.edata['e'].shape[1] + bond_cols,
        prot_graph.ndata['h'].shape[1] + mean_cols,
        prot_graph.edata['e'].shape[1],
        inter_edge_dim,
    )
def trans_device(data, device):
    """Return a list with every non-list element of *data* moved to *device*.

    Elements that are plain ``list`` objects are passed through untouched;
    every other element has its ``.to(device)`` method called.

    Args:
        data: Iterable of batch elements.
        device: Target device forwarded to ``.to``.

    Returns:
        A new list with the same length and order as *data*.
    """
    moved = []
    for element in data:
        if isinstance(element, list):
            # Lists carry no device and are kept as they are.
            moved.append(element)
        else:
            moved.append(element.to(device))
    return moved
def get_data(model_config, ligand_path, protein_path):
    """Build the model input for a single ligand/protein pair.

    Reads both structures, builds the ligand, protein and interaction graphs
    through ``commons``, stores the resulting feature widths on
    ``model_config.model`` and returns the assembled batch moved to
    ``model_config.device``.

    Args:
        model_config: Configuration object. Its ``data`` section supplies the
            graph-building options; its ``model`` section is updated in place
            with the five feature dimensions.
        ligand_path: Path of the ligand structure file.
        protein_path: Path of the protein structure file.

    Returns:
        The list produced by ``trans_device`` for the tuple
        ``(lig_graph, prot_graph, inter_graph, label, item, assay_des,
        IC50_f, K_f)``.
    """
    data_cfg = model_config.data

    parsed = commons.read_molecules_inference(
        ligand_path,
        protein_path,
        data_cfg.prot_graph_type,
        data_cfg.chaincut,
    )
    (
        lig_coords, lig_features, lig_edges, lig_node_type,
        prot_coords, prot_features, _prot_edges, prot_node_type,
        sec_features, alpha_c_coords, c_coords, n_coords,
        ca_res_number_valid, chain_index_valid,
    ) = parsed

    lig_graph = commons.get_lig_graph_equibind(
        lig_coords, lig_features, lig_edges, lig_node_type,
        max_neighbors=data_cfg.lig_max_neighbors,
        cutoff=data_cfg.ligcut,
    )

    prot_graph = commons.get_prot_alpha_c_graph_equibind(
        prot_coords, prot_features, prot_node_type,
        sec_features, alpha_c_coords, c_coords, n_coords,
        max_neighbor=data_cfg.prot_max_neighbors,
        cutoff=data_cfg.protcut,
    )
    prot_graph.ndata['res_number'] = torch.tensor(ca_res_number_valid)
    prot_graph.chain_index = chain_index_valid

    inter_graph = commons.get_interact_graph_knn_v2(
        lig_coords, prot_coords,
        max_neighbor=data_cfg.inter_max_neighbors,
        min_neighbor=data_cfg.inter_min_neighbors,
        cutoff=data_cfg.intercut,
    )

    # The widths must be taken before the concatenations below:
    # load_graph_dim already counts the columns that get appended there.
    dims = load_graph_dim(lig_graph, prot_graph, model_config)
    net_cfg = model_config.model
    (
        net_cfg.lig_node_dim,
        net_cfg.lig_edge_dim,
        net_cfg.pro_node_dim,
        net_cfg.pro_edge_dim,
        net_cfg.inter_edge_dim,
    ) = dims

    if data_cfg.add_chemical_bond_feats:
        lig_edata = lig_graph.edata
        lig_edata['e'] = torch.cat([lig_edata['e'], lig_edata['bond_type']], dim=-1)

    if data_cfg.use_mean_node_features:
        # Same append for ligand first, then protein.
        for graph in (lig_graph, prot_graph):
            graph.ndata['h'] = torch.cat([graph.ndata['h'], graph.ndata['mu_r_norm']], dim=-1)

    # NOTE(review): -100 is presumably a placeholder label, since no measured
    # affinity exists for an uploaded pair -- confirm against the model code.
    label = torch.tensor(-100).unsqueeze(dim=-1)
    assay_des = torch.zeros(0).unsqueeze(dim=0)
    batch = (lig_graph, prot_graph, inter_graph, label, [0], assay_des, [True], [True])
    return trans_device(batch, model_config.device)
def get_models(model_config, model_dir, checkpoint):
    """Instantiate the configured model and restore its weights.

    The model class is looked up by name in this module's globals: the name
    is ``model_config.model.model_type``, with ``'_MTL'`` appended when
    ``model_config.train.multi_task`` is truthy.

    Args:
        model_config: Configuration passed to the model constructor; also
            provides ``device``.
        model_dir: Directory that contains the checkpoint file.
        checkpoint: File name of the checkpoint inside *model_dir*.

    Returns:
        The model on ``model_config.device``, with the ``"model"`` entry of
        the checkpoint loaded and switched to eval mode.
    """
    suffix = '_MTL' if model_config.train.multi_task else ''
    model_cls = globals()[model_config.model.model_type + suffix]
    model = model_cls(model_config).to(model_config.device)

    checkpoint_path = os.path.join(model_dir, checkpoint)
    print("Load checkpoint from %s" % checkpoint_path)

    state = torch.load(checkpoint_path, map_location=model_config.device)
    model.load_state_dict(state["model"])
    return model.eval()
def mbp_scoring(ligand_path, protein_path):
    """Predict IC50 and K for one ligand/protein pair.

    Uses the module-level ``model_config`` and ``model``.

    Args:
        ligand_path: Path of the ligand structure file.
        protein_path: Path of the protein structure file.

    Returns:
        Tuple ``(IC50, K)`` of Python numbers obtained with ``.item()``.
    """
    batch = get_data(model_config, ligand_path, protein_path)

    # The model returns three values; only the middle pair is used.
    _, affinity_pair, _ = model(batch, ASRP=False)
    ic50_pred, k_pred = affinity_pair
    return ic50_pred.item(), k_pred.item()
def test(ligand, protein):
    """Gradio callback: score an uploaded ligand/protein pair.

    Increments and prints the global request counter, then runs
    ``mbp_scoring`` on the paths of the two uploaded files.

    Args:
        ligand: Uploaded ligand file object exposing its path as ``.name``,
            or ``None`` when nothing was uploaded.
        protein: Uploaded protein file object exposing its path as ``.name``,
            or ``None`` when nothing was uploaded.

    Returns:
        Tuple of two strings for the IC50 and K textboxes: the predictions
        formatted with two decimals on success, otherwise the same error
        message twice. Exceptions are never propagated to the caller.
    """
    global total_num
    total_num += 1
    print(f'total num: {total_num}')

    # Report a missing upload clearly instead of an AttributeError on None.
    if ligand is None or protein is None:
        message = 'Please upload both a ligand file and a protein file.'
        return message, message

    try:
        IC50, K = mbp_scoring(ligand.name, protein.name)
    except Exception as e:
        # Keep the UI responsive, but leave the full traceback in the server
        # log so failures can be diagnosed.
        import traceback
        traceback.print_exc()
        message = f'{type(e).__name__}: {e}'
        return message, message

    print(f'ligand file name: {os.path.basename(ligand.name)},'
          f' protein file name: {os.path.basename(protein.name)},'
          f' IC50: {IC50}, K: {K}')
    return '{:.2f}'.format(IC50), '{:.2f}'.format(K)
# Assemble the demo page: introductory Markdown, ligand/protein upload
# widgets, two prediction textboxes, a Submit button and one bundled example.
with gr.Blocks() as demo:
gr.Markdown(
"""
# Multi-task Bioassay Pre-training for Protein-Ligand Binding Affinity Prediction
## Welcome to the MBP demo !
- Feel free to upload your own examples. Please upload an individual ligand 3D file and an individual protein 3D file each time.
- If you encounter any issues, please reach out to jiaxianyan@mail.ustc.edu.cn.
- All codes and data are available on the online platform https://github.com/jiaxianyan/MBP.
""")
# NOTE(review): confirm which of the widgets below are nested inside this Row.
with gr.Row():
ligand = gr.File(label="Ligand 3D file. MBP utilizes openbabel to process ligand files and supports all file types that openbabel can read.")
protein = gr.File(label="Protein 3D file. Currently, MBP only supports the pdb file type for protein files.")
IC50 = gr.Textbox(label="Predicted IC50 Value")
K = gr.Textbox(label="Predicted K Value")
submit_btn = gr.Button("Submit")
# Clicking Submit calls test() with the two uploads and writes its two return
# values to the IC50 and K textboxes; the endpoint is named "MBP_Scoring".
submit_btn.click(fn=test, inputs=[ligand, protein], outputs=[IC50, K], api_name="MBP_Scoring")
gr.Markdown("## Input Examples")
# One example ligand/protein pair that fills the two file inputs. No outputs
# are passed and example caching is disabled.
gr.Examples(
examples=[['./workdir/gradio/1a0q_ligand.sdf','./workdir/gradio/1a0q_protein.pdb']],
inputs=[ligand, protein],
# outputs=[IC50, K],
fn=test,
cache_examples=False,
)
# Load the configuration from model_dir (device forced to CPU, seed applied).
model_config = get_config(model_dir)
# Run get_data() once on the bundled example BEFORE building the model: besides
# returning a batch, it writes the node/edge feature dimensions into
# model_config.model, which the model constructor presumably reads -- TODO confirm.
data_example = get_data(model_config, './workdir/gradio/1a0q_ligand.sdf', './workdir/gradio/1a0q_protein.pdb')
# Instantiate the model and restore its weights from the checkpoint.
model = get_models(model_config, model_dir, checkpoint)
# Start the Gradio app with share=False.
demo.launch(share=False)