# small script to extract the ligand and save it in a separate file because GNINA will use the ligand position as
# initial pose
import os
import shutil
import subprocess
import sys
import time
from argparse import ArgumentParser, FileType
from datetime import datetime

import numpy as np
import pandas as pd
from biopandas.pdb import PandasPdb
from rdkit import Chem
from rdkit.Chem import AllChem, MolToPDBFile
from scipy.spatial.distance import cdist

from datasets.pdbbind import read_mol
from utils.utils import read_strings_from_txt

parser = ArgumentParser()
parser.add_argument('--data_dir', type=str, default='data/PDBBind_processed', help='')
parser.add_argument('--file_suffix', type=str, default='_baseline_ligand',
                    help='Suffix appended to the complex name for the prediction output and log files')
parser.add_argument('--results_path', type=str, default='results/gnina_predictions', help='')
parser.add_argument('--complex_names_path', type=str, default='data/splits/timesplit_test', help='')
parser.add_argument('--seed_molecules_path', type=str, default=None,
                    help='Use the molecules at seed molecule path as initialization and only search around them')
parser.add_argument('--seed_molecule_filename', type=str, default='equibind_corrected.sdf',
                    help='Filename of the seed molecule inside each complex folder of seed_molecules_path')
parser.add_argument('--smina', action='store_true', default=False, help='')
parser.add_argument('--no_gpu', action='store_true', default=False, help='')
parser.add_argument('--exhaustiveness', type=int, default=8, help='')
parser.add_argument('--num_cpu', type=int, default=16, help='')
parser.add_argument('--pocket_mode', action='store_true', default=False, help='')
parser.add_argument('--pocket_cutoff', type=int, default=5, help='')
parser.add_argument('--num_modes', type=int, default=10, help='')
parser.add_argument('--autobox_add', type=int, default=4, help='')
parser.add_argument('--use_p2rank_pocket', action='store_true', default=False, help='')
parser.add_argument('--skip_p2rank', action='store_true', default=False, help='')
parser.add_argument('--prank_path', type=str, default='/Users/hstark/projects/p2rank_2.3/prank', help='')
parser.add_argument('--skip_existing', action='store_true', default=False, help='')
args = parser.parse_args()


class Logger(object):
    """File-like object that mirrors every write to a stream and to a log file.

    Instances are installed as ``sys.stdout`` / ``sys.stderr`` below so that
    everything printed by this script is also appended to a file.
    """

    def __init__(self, logpath, syspart=sys.stdout):
        """Open ``logpath`` in append mode and remember the stream to mirror to."""
        self.terminal = syspart
        # The handle is intentionally kept open for the lifetime of the process.
        self.log = open(logpath, "a")

    def write(self, message):
        """Write ``message`` to the wrapped stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)
        # Flush on every write so the log file is current while the script runs.
        self.log.flush()

    def flush(self):
        # this flush method is needed for python 3 compatibility.
        # this handles the flush command by doing nothing.
        # you might want to specify some extra behavior here.
        pass


def log(*messages):
    """Print ``messages`` prefixed with the current timestamp."""
    print(f'[{datetime.now()}]', *messages)


def _pocket_center(rec_pos, lig_pos, cutoff):
    """Return the centre of the binding pocket for the reference ligand.

    The centre is the mean position of all receptor points that have at least
    one ligand atom closer than ``cutoff``. If no receptor point is that close,
    the receptor point with the smallest distance to any ligand atom is used.
    """
    d = cdist(rec_pos, lig_pos)
    label = np.any(d < cutoff, axis=1)
    if np.any(label):
        return rec_pos[label].mean(axis=0)
    print("No pocket residue below minimum distance ", cutoff, "taking closest at", np.min(d))
    # Index of the receptor point with the smallest distance to any ligand
    # atom. The reduction over axis 1 must stay an array: indexing it with [0]
    # first (as was done previously) always selected the first receptor point.
    return rec_pos[np.argmin(np.min(d, axis=1))]


# parameters
names = read_strings_from_txt(args.complex_names_path)

# Start from a clean results directory unless existing predictions are reused.
if os.path.exists(args.results_path) and not args.skip_existing:
    shutil.rmtree(args.results_path)
os.makedirs(args.results_path, exist_ok=True)
sys.stdout = Logger(logpath=f'{args.results_path}/gnina.log', syspart=sys.stdout)
sys.stderr = Logger(logpath=f'{args.results_path}/error.log', syspart=sys.stderr)

p2rank_cache_path = "results/.p2rank_cache"
if args.use_p2rank_pocket and not args.skip_p2rank:
    # Copy all receptors into the cache and run p2rank once over the whole list.
    os.makedirs(p2rank_cache_path, exist_ok=True)
    pdb_files_cache = os.path.join(p2rank_cache_path, 'pdb_files')
    os.makedirs(pdb_files_cache, exist_ok=True)
    with open(f"{p2rank_cache_path}/pdb_list_p2rank.txt", "w") as out:
        for name in names:
            shutil.copy(os.path.join(args.data_dir, name, f'{name}_protein_processed.pdb'),
                        f'{pdb_files_cache}/{name}_protein_processed.pdb')
            out.write(os.path.join('pdb_files', f'{name}_protein_processed.pdb\n'))
    cmd = f"bash {args.prank_path} predict {p2rank_cache_path}/pdb_list_p2rank.txt -o {p2rank_cache_path}/p2rank_output -threads 4"
    os.system(cmd)

all_times = []
start_time = time.time()
for i, name in enumerate(names):
    os.makedirs(os.path.join(args.results_path, name), exist_ok=True)
    log('\n')
    log(f'complex {i} of {len(names)}')
    # call gnina to find binding pose
    rec_path = os.path.join(args.data_dir, name, f'{name}_protein_processed.pdb')
    prediction_output_name = os.path.join(args.results_path, name, f'{name}{args.file_suffix}.pdb')
    log_path = os.path.join(args.results_path, name, f'{name}{args.file_suffix}.log')
    if args.seed_molecules_path is not None:
        seed_mol_path = os.path.join(args.seed_molecules_path, name, f'{args.seed_molecule_filename}')
    if args.skip_existing and os.path.exists(prediction_output_name):
        continue

    if args.pocket_mode:
        # Build the search box from the reference ligand position: centred on
        # the nearby C-alpha atoms and large enough to contain the ligand.
        mol = read_mol(args.data_dir, name, remove_hs=False)
        rec = PandasPdb().read_pdb(rec_path)
        rec_df = rec.get(s='c-alpha')
        rec_pos = rec_df[['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32)
        lig_pos = mol.GetConformer().GetPositions()
        center_pocket = _pocket_center(rec_pos, lig_pos, args.pocket_cutoff)
        radius_pocket = np.max(np.linalg.norm(lig_pos - center_pocket[None, :], axis=1))
        diameter_pocket = radius_pocket * 2
        center_x = center_pocket[0]
        size_x = diameter_pocket + 8
        center_y = center_pocket[1]
        size_y = diameter_pocket + 8
        center_z = center_pocket[2]
        size_z = diameter_pocket + 8

    mol_rdkit = read_mol(args.data_dir, name, remove_hs=False)
    single_time = time.time()

    # Discard the loaded coordinates and generate a fresh RDKit conformer that
    # is written to disk and passed to gnina as the ligand.
    mol_rdkit.RemoveAllConformers()
    ps = AllChem.ETKDGv2()
    conf_id = AllChem.EmbedMolecule(mol_rdkit, ps)
    if conf_id == -1:
        print('rdkit pos could not be generated without using random pos. using random pos now.')
        ps.useRandomCoords = True
        AllChem.EmbedMolecule(mol_rdkit, ps)
    AllChem.MMFFOptimizeMolecule(mol_rdkit, confId=0)
    rdkit_mol_path = os.path.join(args.data_dir, name, f'{name}_rdkit_ligand.pdb')
    MolToPDBFile(mol_rdkit, rdkit_mol_path)

    fallback_without_p2rank = False
    if args.use_p2rank_pocket:
        df = pd.read_csv(f'{p2rank_cache_path}/p2rank_output/{name}_protein_processed.pdb_predictions.csv')
        rdkit_lig_pos = mol_rdkit.GetConformer().GetPositions()
        # Box edge: largest intra-ligand distance plus padding on both sides.
        diameter_pocket = np.max(cdist(rdkit_lig_pos, rdkit_lig_pos))
        size_x = diameter_pocket + args.autobox_add * 2
        size_y = diameter_pocket + args.autobox_add * 2
        size_z = diameter_pocket + args.autobox_add * 2
        if df.empty:
            # The predictions file has no rows: use the autobox command instead.
            fallback_without_p2rank = True
        else:
            # The column names are looked up with their leading space.
            center_x = df.iloc[0]['   center_x']
            center_y = df.iloc[0]['   center_y']
            center_z = df.iloc[0]['   center_z']

    log(f'processing {rec_path}')
    gpu_flag = '--no_gpu' if args.no_gpu else ''
    scoring_flag = '--cnn_scoring none' if args.smina else ''
    if (not args.pocket_mode and not args.use_p2rank_pocket) or fallback_without_p2rank:
        # No explicit box available: let gnina derive it with --autobox_ligand
        # from the receptor file, or from the seed molecule when one is given.
        autobox_ligand = rec_path if args.seed_molecules_path is None else seed_mol_path
        gnina_cmd = (
            f"gnina --receptor {rec_path} --ligand {rdkit_mol_path} --num_modes {args.num_modes} "
            f"-o {prediction_output_name} {gpu_flag} --autobox_ligand {autobox_ligand} "
            f"--autobox_add {args.autobox_add} --log {log_path} --exhaustiveness {args.exhaustiveness} "
            f"--cpu {args.num_cpu} {scoring_flag}"
        )
    else:
        gnina_cmd = (
            f"gnina --receptor {rec_path} --ligand {rdkit_mol_path} --num_modes {args.num_modes} "
            f"-o {prediction_output_name} {gpu_flag} --log {log_path} "
            f"--exhaustiveness {args.exhaustiveness} --cpu {args.num_cpu} {scoring_flag} "
            f"--center_x {center_x} --center_y {center_y} --center_z {center_z} "
            f"--size_x {size_x} --size_y {size_y} --size_z {size_z}"
        )
    result = subprocess.run(gnina_cmd, shell=True)
    log(result)
    all_times.append(time.time() - single_time)

    log("single time: --- %s seconds ---" % (time.time() - single_time))
    log("time so far: --- %s seconds ---" % (time.time() - start_time))
    log('\n')
log(all_times)
log("--- %s seconds ---" % (time.time() - start_time))