|
import gradio as gr |
|
import os |
|
|
|
import copy |
|
|
import torch |
|
|
|
import time |
|
from argparse import Namespace
|
from rdkit.Chem import RemoveHs |
|
from functools import partial |
|
import numpy as np |
|
import pandas as pd |
|
from rdkit import RDLogger |
|
from rdkit.Chem import MolFromSmiles, AddHs |
|
from torch_geometric.loader import DataLoader |
|
import yaml |
|
import sys |
|
import csv |
|
|
|
csv.field_size_limit(sys.maxsize) |
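# Raised here so downstream CSV reads do not fail on very long fields.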
|
|
|
print(torch.__version__) |
|
os.makedirs("data/esm2_output", exist_ok=True) |
|
os.makedirs("results", exist_ok=True) |
|
from datasets.process_mols import ( |
|
read_molecule, |
|
generate_conformer, |
|
write_mol_with_coords, |
|
) |
|
from datasets.pdbbind import PDBBind |
|
from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl, get_t_schedule |
|
from utils.sampling import randomize_position, sampling |
|
from utils.utils import get_model |
|
from utils.visualise import PDBFile |
|
from tqdm import tqdm |
|
from datasets.esm_embedding_preparation import esm_embedding_prep |
|
import subprocess |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
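# Prefer the GPU; inference also runs (much more slowly) on CPU.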
|
|
|
with open(f"workdir/paper_score_model/model_parameters.yml") as f: |
|
score_model_args = Namespace(**yaml.full_load(f)) |
|
|
|
with open(f"workdir/paper_confidence_model/model_parameters.yml") as f: |
|
confidence_args = Namespace(**yaml.full_load(f)) |
|
|
|
import shutil |
|
|
|
t_to_sigma = partial(t_to_sigma_compl, args=score_model_args) |
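# t_to_sigma converts a diffusion time t into translation/rotation/torsion noise scales using the score model's hyperparameters.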
|
|
|
model = get_model(score_model_args, device, t_to_sigma=t_to_sigma, no_parallel=True) |
|
state_dict = torch.load( |
|
f"workdir/paper_score_model/best_ema_inference_epoch_model.pt", |
|
map_location=torch.device("cpu"), |
|
) |
|
model.load_state_dict(state_dict, strict=True) |
|
model = model.to(device) |
|
model.eval() |
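
# The confidence model scores and ranks the poses produced by the score model.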
|
|
|
confidence_model = get_model( |
|
confidence_args, |
|
device, |
|
t_to_sigma=t_to_sigma, |
|
no_parallel=True, |
|
confidence_mode=True, |
|
) |
|
state_dict = torch.load( |
|
f"workdir/paper_confidence_model/best_model_epoch75.pt", |
|
map_location=torch.device("cpu"), |
|
) |
|
confidence_model.load_state_dict(state_dict, strict=True) |
|
confidence_model = confidence_model.to(device) |
|
confidence_model.eval() |
|
|
|
|
|
def get_pdb(pdb_code="", filepath=""): |
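    """Return a path to a PDB file: prefer an uploaded file, otherwise download pdb_code from RCSB."""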
|
try: |
|
return filepath.name |
|
except AttributeError as e: |
|
if pdb_code is None or pdb_code == "": |
|
return None |
|
else: |
|
os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb") |
|
return f"{pdb_code}.pdb" |
|
|
|
|
|
def get_ligand(smiles="", filepath=""): |
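    """Return the ligand as a SMILES string when provided, otherwise the path of the uploaded file."""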
|
if smiles is None or smiles == "": |
|
try: |
|
return filepath.name |
|
except AttributeError as e: |
|
return None |
|
else: |
|
return smiles |
|
|
|
|
|
def read_mol(molpath): |
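    """Return the raw text content of a molecule file."""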
|
with open(molpath, "r") as fp: |
|
lines = fp.readlines() |
|
mol = "" |
|
for l in lines: |
|
mol += l |
|
return mol |
|
|
|
|
|
def molecule(input_pdb, ligand_pdb, original_ligand): |
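    """Build a self-contained 3Dmol.js page showing the protein, the predicted ligand diffusion
    frames and, if provided, the uploaded reference ligand; returned as an iframe snippet for Gradio."""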
|
|
|
structure = read_mol(input_pdb) |
|
mol = read_mol(ligand_pdb) |
|
|
|
try: |
|
ligand = read_mol(original_ligand.name) |
|
_, ext = os.path.splitext(original_ligand.name) |
|
lig_str_1 = """let original_ligand = `""" + ligand + """`""" |
|
lig_str_2 = f""" |
|
viewer.addModel( original_ligand, "{ext[1:]}" ); |
|
viewer.getModel(2).setStyle({{stick:{{colorscheme:"greenCarbon"}}}});""" |
|
except AttributeError as e: |
|
ligand = None |
|
lig_str_1 = "" |
|
lig_str_2 = "" |
|
|
|
x = ( |
|
"""<!DOCTYPE html> |
|
<html> |
|
<head> |
|
<meta http-equiv="content-type" content="text/html; charset=UTF-8" /> |
|
<style> |
|
body{ |
|
font-family:sans-serif |
|
} |
|
.mol-container { |
|
width: 600px; |
|
height: 600px; |
|
position: relative; |
|
margin: 0 auto;
|
} |
|
.mol-container select{ |
|
background-image:None; |
|
} |
|
.green{ |
|
width:20px; |
|
height:20px; |
|
background-color:#33ff45; |
|
display:inline-block; |
|
} |
|
.magenta{ |
|
width:20px; |
|
height:20px; |
|
background-color:magenta; |
|
display:inline-block; |
|
} |
|
</style> |
|
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script> |
|
</head> |
|
<body> |
|
<button id="startanimation">Replay diffusion process</button> |
|
<button id="togglesurface">Toggle surface representation</button> |
|
<div> |
|
<span class="green"></span> Uploaded ligand position |
|
<span class="magenta"></span> Predicted ligand position |
|
</div> |
|
<div id="container" class="mol-container"></div> |
|
|
|
<script> |
|
let ligand = `""" |
|
+ mol |
|
+ """` |
|
let structure = `""" |
|
+ structure |
|
+ """` |
|
""" |
|
+ lig_str_1 |
|
+ """ |
|
|
|
let viewer = null; |
|
let surface = false; |
|
let surf = null; |
|
$(document).ready(function () { |
|
let element = $("#container"); |
|
let config = { backgroundColor: "white" }; |
|
viewer = $3Dmol.createViewer(element, config); |
|
viewer.addModel( structure, "pdb" ); |
|
viewer.setStyle({}, {cartoon: {color: "gray"}}); |
|
viewer.zoomTo(); |
|
viewer.zoom(0.7); |
|
viewer.addModelsAsFrames(ligand, "pdb"); |
|
viewer.animate({loop: "forward",reps: 1}); |
|
|
|
viewer.getModel(1).setStyle({stick:{colorscheme:"magentaCarbon"}}); |
|
""" |
|
+ lig_str_2 |
|
+ """ |
|
viewer.render(); |
|
|
|
}) |
|
|
|
$("#startanimation").click(function() { |
|
viewer.animate({loop: "forward",reps: 1}); |
|
}); |
|
$("#togglesurface").click(function() { |
|
if (surface != true) { |
|
surf = viewer.addSurface($3Dmol.SurfaceType.VDW, { "opacity": 0.9, "color": "white" }, { model: 0 }); |
|
surface = true; |
|
} else { |
|
viewer.removeAllSurfaces() |
|
surface = false; |
|
} |
|
}); |
|
</script> |
|
</body></html>""" |
|
) |
|
|
|
return f"""<iframe style="width: 100%; height: 700px" name="result" allow="midi; geolocation; microphone; camera; |
|
display-capture; encrypted-media;" sandbox="allow-modals allow-forms |
|
allow-scripts allow-same-origin allow-popups |
|
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" |
|
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>""" |
|
|
|
|
|
|
|
|
|
|
def esm(protein_path, out_file): |
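    """Write a FASTA for the protein and run ESM-2 (esm2_t33_650M_UR50D) to produce per-residue embeddings in data/esm2_output."""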
|
print("running esm") |
|
esm_embedding_prep(out_file, protein_path) |
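
    # HOME is redirected so that the ESM extraction script resolves its torch-hub weight cache inside esm/model_weights.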
|
|
|
os.environ["HOME"] = "esm/model_weights" |
|
subprocess.call( |
|
f"python esm/scripts/extract.py esm2_t33_650M_UR50D {out_file} data/esm2_output --repr_layers 33 --include per_tok", |
|
shell=True, |
|
env=os.environ, |
|
) |
|
|
|
|
|
def update(inp, file, ligand_inp, ligand_file, n_it): |
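    """Main entry point: prepare the inputs, compute ESM embeddings, run reverse-diffusion sampling,
    rank the poses with the confidence model, and write the SDF/PDB outputs."""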
|
pdb_path = get_pdb(inp, file) |
|
ligand_path = get_ligand(ligand_inp, ligand_file) |
|
|
|
esm( |
|
pdb_path, |
|
f"data/{os.path.basename(pdb_path)}_prepared_for_esm.fasta", |
|
) |
|
tr_schedule = get_t_schedule(inference_steps=n_it) |
|
rot_schedule = tr_schedule |
|
tor_schedule = tr_schedule |
|
print("common t schedule", tr_schedule) |
|
( |
|
failures, |
|
skipped, |
|
confidences_list, |
|
names_list, |
|
run_times, |
|
min_self_distances_list, |
|
) = ( |
|
0, |
|
0, |
|
[], |
|
[], |
|
[], |
|
[], |
|
) |
|
N = 10 |
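    # number of poses sampled per complex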
|
protein_path_list = [pdb_path] |
|
ligand_descriptions = [ligand_path] |
|
no_random = False |
|
ode = False |
|
no_final_step_noise = False |
|
out_dir = "results/" |
|
test_dataset = PDBBind( |
|
transform=None, |
|
root="", |
|
protein_path_list=protein_path_list, |
|
ligand_descriptions=ligand_descriptions, |
|
receptor_radius=score_model_args.receptor_radius, |
|
cache_path="data/cache", |
|
remove_hs=score_model_args.remove_hs, |
|
max_lig_size=None, |
|
c_alpha_max_neighbors=score_model_args.c_alpha_max_neighbors, |
|
matching=False, |
|
keep_original=False, |
|
popsize=score_model_args.matching_popsize, |
|
maxiter=score_model_args.matching_maxiter, |
|
all_atoms=score_model_args.all_atoms, |
|
atom_radius=score_model_args.atom_radius, |
|
atom_max_neighbors=score_model_args.atom_max_neighbors, |
|
esm_embeddings_path="data/esm2_output", |
|
require_ligand=True, |
|
num_workers=1, |
|
keep_local_structures=False, |
|
) |
|
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False) |
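
    # A second dataset built with the confidence model's preprocessing settings.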
|
confidence_test_dataset = PDBBind( |
|
transform=None, |
|
root="", |
|
protein_path_list=protein_path_list, |
|
ligand_descriptions=ligand_descriptions, |
|
receptor_radius=confidence_args.receptor_radius, |
|
cache_path="data/cache", |
|
remove_hs=confidence_args.remove_hs, |
|
max_lig_size=None, |
|
c_alpha_max_neighbors=confidence_args.c_alpha_max_neighbors, |
|
matching=False, |
|
keep_original=False, |
|
popsize=confidence_args.matching_popsize, |
|
maxiter=confidence_args.matching_maxiter, |
|
all_atoms=confidence_args.all_atoms, |
|
atom_radius=confidence_args.atom_radius, |
|
atom_max_neighbors=confidence_args.atom_max_neighbors, |
|
esm_embeddings_path="data/esm2_output", |
|
require_ligand=True, |
|
num_workers=1, |
|
) |
|
confidence_complex_dict = {d.name: d for d in confidence_test_dataset} |
|
for idx, orig_complex_graph in tqdm(enumerate(test_loader)): |
|
if ( |
|
confidence_model is not None |
|
and not ( |
|
confidence_args.use_original_model_cache |
|
or confidence_args.transfer_weights |
|
) |
|
and orig_complex_graph.name[0] not in confidence_complex_dict.keys() |
|
): |
|
skipped += 1 |
|
print( |
|
f"HAPPENING | The confidence dataset did not contain {orig_complex_graph.name[0]}. We are skipping this complex." |
|
) |
|
continue |
|
try: |
|
data_list = [copy.deepcopy(orig_complex_graph) for _ in range(N)] |
|
randomize_position( |
|
data_list, |
|
score_model_args.no_torsion, |
|
no_random, |
|
score_model_args.tr_sigma_max, |
|
) |
|
pdb = None |
|
lig = orig_complex_graph.mol[0] |
|
visualization_list = [] |
|
for graph in data_list: |
|
pdb = PDBFile(lig) |
|
pdb.add(lig, 0, 0) |
|
pdb.add( |
|
( |
|
orig_complex_graph["ligand"].pos |
|
+ orig_complex_graph.original_center |
|
) |
|
.detach() |
|
.cpu(), |
|
1, |
|
0, |
|
) |
|
pdb.add( |
|
(graph["ligand"].pos + graph.original_center).detach().cpu(), |
|
part=1, |
|
order=1, |
|
) |
|
visualization_list.append(pdb) |
|
|
|
start_time = time.time() |
|
if confidence_model is not None and not ( |
|
confidence_args.use_original_model_cache |
|
or confidence_args.transfer_weights |
|
): |
|
confidence_data_list = [ |
|
copy.deepcopy(confidence_complex_dict[orig_complex_graph.name[0]]) |
|
for _ in range(N) |
|
] |
|
else: |
|
confidence_data_list = None |
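
            # Reverse diffusion: generate N candidate poses with the score model and, when available, score them with the confidence model.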
|
|
|
data_list, confidence = sampling( |
|
data_list=data_list, |
|
model=model, |
|
inference_steps=n_it, |
|
tr_schedule=tr_schedule, |
|
rot_schedule=rot_schedule, |
|
tor_schedule=tor_schedule, |
|
device=device, |
|
t_to_sigma=t_to_sigma, |
|
model_args=score_model_args, |
|
no_random=no_random, |
|
ode=ode, |
|
visualization_list=visualization_list, |
|
confidence_model=confidence_model, |
|
confidence_data_list=confidence_data_list, |
|
confidence_model_args=confidence_args, |
|
batch_size=1, |
|
no_final_step_noise=no_final_step_noise, |
|
) |
|
ligand_pos = np.asarray( |
|
[ |
|
complex_graph["ligand"].pos.cpu().numpy() |
|
+ orig_complex_graph.original_center.cpu().numpy() |
|
for complex_graph in data_list |
|
] |
|
) |
|
run_times.append(time.time() - start_time) |
|
|
|
if confidence is not None and isinstance( |
|
confidence_args.rmsd_classification_cutoff, list |
|
): |
|
confidence = confidence[:, 0] |
|
if confidence is not None: |
|
confidence = confidence.cpu().numpy() |
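
                # Reorder poses from highest to lowest confidence.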
|
re_order = np.argsort(confidence)[::-1] |
|
confidence = confidence[re_order] |
|
confidences_list.append(confidence) |
|
ligand_pos = ligand_pos[re_order] |
|
write_dir = ( |
|
f'{out_dir}/index{idx}_{data_list[0]["name"][0].replace("/","-")}' |
|
) |
|
os.makedirs(write_dir, exist_ok=True) |
|
confidences = [] |
|
for rank, pos in enumerate(ligand_pos): |
|
mol_pred = copy.deepcopy(lig) |
|
if score_model_args.remove_hs: |
|
mol_pred = RemoveHs(mol_pred) |
|
if rank == 0: |
|
write_mol_with_coords( |
|
mol_pred, pos, os.path.join(write_dir, f"rank{rank+1}.sdf") |
|
) |
|
confidences.append(confidence[rank]) |
|
write_mol_with_coords( |
|
mol_pred, |
|
pos, |
|
os.path.join( |
|
write_dir, f"rank{rank+1}_confidence{confidence[rank]:.2f}.sdf" |
|
), |
|
) |
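
            # Minimum intra-ligand atom distance per pose (self-distances with the diagonal masked out) serves as a sanity check.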
|
self_distances = np.linalg.norm( |
|
ligand_pos[:, :, None, :] - ligand_pos[:, None, :, :], axis=-1 |
|
) |
|
self_distances = np.where( |
|
np.eye(self_distances.shape[2]), np.inf, self_distances |
|
) |
|
min_self_distances_list.append(np.min(self_distances, axis=(1, 2))) |
|
|
|
filenames = [] |
|
if confidence is not None: |
|
for rank, batch_idx in enumerate(re_order): |
|
visualization_list[batch_idx].write( |
|
os.path.join(write_dir, f"rank{rank+1}_reverseprocess.pdb") |
|
) |
|
filenames.append( |
|
os.path.join(write_dir, f"rank{rank+1}_reverseprocess.pdb") |
|
) |
|
else: |
|
                for rank in range(len(ligand_pos)):

                    visualization_list[rank].write(
|
os.path.join(write_dir, f"rank{rank+1}_reverseprocess.pdb") |
|
) |
|
filenames.append( |
|
os.path.join(write_dir, f"rank{rank+1}_reverseprocess.pdb") |
|
) |
|
names_list.append(orig_complex_graph.name[0]) |
|
except Exception as e: |
|
print("Failed on", orig_complex_graph["name"], e) |
|
failures += 1 |
|
return None |
|
|
|
zippath = shutil.make_archive( |
|
os.path.join("results", os.path.basename(pdb_path)), "zip", write_dir |
|
) |
|
print("Zipped outputs to", zippath) |
|
labels = [ |
|
f"rank {i+1}, confidence {confidences[i]:.2f}" for i in range(len(filenames)) |
|
] |
|
|
|
torch.cuda.empty_cache() |
|
return ( |
|
molecule(pdb_path, filenames[0], ligand_file), |
|
gr.Dropdown.update(choices=labels, value=labels[0]), |
|
filenames, |
|
pdb_path, |
|
zippath, |
|
) |
|
|
|
|
|
def updateView(out, filenames, pdb, ligand_file): |
|
print("updating view") |
|
i = out |
|
print(i) |
|
i = int(i.split(",")[0].replace("rank", "")) - 1 |
|
return molecule(pdb, filenames[i], ligand_file) |
|
|
|
|
|
demo = gr.Blocks() |
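
# Gradio UI: protein/ligand inputs, inference-step slider, ranked-pose dropdown, 3D viewer, and a downloadable results archive.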
|
|
|
with demo: |
|
gr.Markdown("# DiffDock") |
|
gr.Markdown( |
|
">**DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking**, Corso, Gabriele and Stärk, Hannes and Jing, Bowen and Barzilay, Regina and Jaakkola, Tommi, arXiv:2210.01776 [GitHub](https://github.com/gcorso/diffdock)" |
|
) |
|
gr.Markdown("") |
|
with gr.Box(): |
|
with gr.Row(): |
|
with gr.Column(): |
|
gr.Markdown("## Protein") |
|
inp = gr.Textbox( |
|
placeholder="PDB Code or upload file below", label="Input structure" |
|
) |
|
file = gr.File(file_count="single", label="Input PDB") |
|
with gr.Column(): |
|
gr.Markdown("## Ligand") |
|
ligand_inp = gr.Textbox( |
|
placeholder="Provide SMILES input or upload mol2/sdf file below", |
|
label="SMILES string", |
|
) |
|
ligand_file = gr.File(file_count="single", label="Input Ligand") |
|
n_it = gr.Slider( |
|
minimum=10, maximum=40, label="Number of inference steps", step=1 |
|
) |
|
|
|
btn = gr.Button("Run predictions") |
|
|
|
gr.Markdown("## Output") |
|
pdb = gr.Variable() |
|
filenames = gr.Variable() |
|
out = gr.Dropdown(interactive=True, label="Ranked samples") |
|
mol = gr.HTML() |
|
output_file = gr.File(file_count="single", label="Output files") |
|
gr.Examples( |
|
[ |
|
[ |
|
"6w70", |
|
"examples/6w70.pdb", |
|
"COc1ccc(cc1)n2c3c(c(n2)C(=O)N)CCN(C3=O)c4ccc(cc4)N5CCCCC5=O", |
|
"examples/6w70_ligand.sdf", |
|
10, |
|
], |
|
[ |
|
"6moa", |
|
"examples/6moa_protein_processed.pdb", |
|
"", |
|
"examples/6moa_ligand.sdf", |
|
10, |
|
], |
|
[ |
|
"", |
|
"examples/6o5u_protein_processed.pdb", |
|
"", |
|
"examples/6o5u_ligand.sdf", |
|
10, |
|
], |
|
[ |
|
"", |
|
"examples/6o5u_protein_processed.pdb", |
|
"[NH3+]C[C@H]1O[C@H](O[C@@H]2[C@@H]([NH3+])C[C@H]([C@@H]([C@H]2O)O[C@H]2O[C@H](CO)[C@H]([C@@H]([C@H]2O)[NH3+])O)[NH3+])[C@@H]([C@H]([C@@H]1O)O)O", |
|
"examples/6o5u_ligand.sdf", |
|
10, |
|
], |
|
[ |
|
"", |
|
"examples/6o5u_protein_processed.pdb", |
|
"", |
|
"examples/6o5u_ligand.sdf", |
|
10, |
|
], |
|
[ |
|
"", |
|
"examples/6ahs_protein_processed.pdb", |
|
"", |
|
"examples/6ahs_ligand.sdf", |
|
10, |
|
], |
|
], |
|
[inp, file, ligand_inp, ligand_file, n_it], |
|
        [mol, out, filenames, pdb, output_file],

    )
|
btn.click( |
|
fn=update, |
|
inputs=[inp, file, ligand_inp, ligand_file, n_it], |
|
outputs=[mol, out, filenames, pdb, output_file], |
|
) |
|
out.change(fn=updateView, inputs=[out, filenames, pdb, ligand_file], outputs=mol) |
|
demo.launch() |
|
|