import time
import multiprocessing
from multiprocessing import Pool

import numpy as np
import torch
from moleculekit.molecule import Molecule

# NOTE(review): the module paths of the next three imports follow the public
# moleculekit package layout -- confirm against the installed version.
from moleculekit.tools.voxeldescriptors import getVoxelDescriptors
from moleculekit.tools.atomtyper import prepareProteinForAtomtyping
from moleculekit.tools.preparation import systemPrepare


# Error categories of this module. Only VoxelizationError is raised by the
# functions defined below.
class AtomtypingError(Exception):
    pass


class StructureCleaningError(Exception):
    pass


class ProteinPrepareError(Exception):
    pass


class VoxelizationError(Exception):
    pass


# Metal element symbols in upper-case and capitalised spelling. Not referenced
# by the functions defined below.
metal_atypes = (
    "MG",
    "ZN",
    "MN",
    "CA",
    "FE",
    "HG",
    "CD",
    "NI",
    "CO",
    "CU",
    "K",
    "LI",
    "Mg",
    "Zn",
    "Mn",
    "Ca",
    "Fe",
    "Hg",
    "Cd",
    "Ni",
    "Co",
    "Cu",
    "Li",
)

# Geometry of one voxel box: edge length and voxel spacing handed to
# getVoxelDescriptors. 16 / 0.5 gives the 32 voxels per axis that
# processStructures pre-allocates for.
_BOX_SIZE = (16, 16, 16)
_VOXEL_SIZE = 0.5
_GRID_POINTS = 32

# Atom selections of the user channels, in channel order.
_CHANNEL_SELECTIONS = (
    # hydrophobic
    "element C",
    # aromatic
    "resname HIS HIE HIP HID TRP TYR PHE and sidechain and not name CB and not hydrogen",
    # metal coordination
    "(name ND1 NE2 SG OE1 OE2 OD2) or (protein and name O N)",
    # hydrogen-bond acceptor
    "(resname ASP GLU HIS HIE HIP HID SER THR MSE CYS MET and name ND2 NE2 OE1 OE2 OD1 OD2 OG OG1 SE SG) or name O",
    # hydrogen-bond donor
    "(resname ASN GLN ASH GLH TRP MSE SER THR MET CYS and name ND2 NE2 NE1 SG SE OG OG1) or name N",
    # positively charged
    "resname LYS ARG HIS HIE HIP HID and name NZ NH1 NH2 ND1 NE2 NE",
    # negatively charged
    "(resname ASP GLU ASH GLH and name OD1 OD2 OE1 OE2)",
    # occupancy
    "protein and not hydrogen",
)
_N_CHANNELS = len(_CHANNEL_SELECTIONS)


def _selection_column(prot, selection):
    """Return the result of ``prot.atomselect(selection)`` as an (n, 1) column."""
    selected = prot.atomselect(selection)
    # Each channel is reshaped with its own length (the original reused the
    # length of the metal-coordination channel for two other channels).
    return selected.reshape(selected.shape[0], 1)


def voxelize_single_notcentered(env):
    """Voxelize one box of a structure around a single CA atom.

    Eight user-defined channels are built from atom selections, in this
    order: hydrophobic, aromatic, metal coordination, hydrogen-bond acceptor,
    hydrogen-bond donor, positive, negative, occupancy. They are passed to
    ``getVoxelDescriptors`` as ``userchannels`` together with a box size of
    16 per axis and a voxel size of 0.5.

    Parameters
    ----------
    env : tuple
        Tuple of the form (prot, idx). ``idx`` is interpolated into the
        selection ``"index {idx} and name CA"`` whose coordinates are used as
        the box center.

    Returns
    -------
    voxels : torch.Tensor
        Voxelized box reshaped to (1, nchannels, N[0], N[1], N[2]), where
        ``nchannels`` is the second dimension of the descriptor array and
        ``N`` is ``prot_N``.
    prot_centers : array
        Voxel centers as returned by ``getVoxelDescriptors``.
    prot_N : array
        Number of voxels along each of the three axes, as returned by
        ``getVoxelDescriptors``.
    prot : moleculekit.Molecule
        Copy of the input molecule.

    Raises
    ------
    VoxelizationError
        If building the channels or computing the descriptors fails. The
        underlying exception is attached as ``__cause__``.
    """
    prot, idx = env
    # The center lookup stays outside the try block, as in the original, so
    # its errors propagate unchanged.
    center = prot.get("coords", sel=f"index {idx} and name CA")

    try:
        userchannels = np.hstack(
            [_selection_column(prot, sel) for sel in _CHANNEL_SELECTIONS]
        )
        prot_vox, prot_centers, prot_N = getVoxelDescriptors(
            prot,
            center=center,
            userchannels=userchannels,
            boxsize=list(_BOX_SIZE),
            voxelsize=_VOXEL_SIZE,
            validitychecks=False,
        )
    except Exception as err:
        # `except Exception` (not a bare except) lets KeyboardInterrupt and
        # SystemExit pass through; `from err` keeps the root cause visible.
        raise VoxelizationError(f"voxelization of {idx} failed") from err

    nchannels = prot_vox.shape[1]
    # Transpose puts the channel axis first, then the flat voxel axis is
    # unfolded into the 3D grid. The copy gives torch a contiguous buffer.
    prot_vox_t = (
        prot_vox.transpose()
        .reshape([1, nchannels, prot_N[0], prot_N[1], prot_N[2]])
        .copy()
    )
    return (torch.from_numpy(prot_vox_t), prot_centers, prot_N, prot.copy())


def processStructures(pdb_file, resids, clean=True):
    """Load a structure file and voxelize one box per entry of ``resids``.

    Parameters
    ----------
    pdb_file : str
        Path to the pdb file, passed to ``Molecule``.
    resids : list
        Values the boxes are centered on. Each value is forwarded to
        ``voxelize_single_notcentered``, which uses it in an
        ``index ... and name CA`` atom selection.
    clean : bool
        If True, keep only atoms matching ``"protein and not hydrogen"``.

    Returns
    -------
    voxels : torch.Tensor
        Voxelized boxes of shape (N, 8, 32, 32, 32), allocated on CUDA when
        available and on the CPU otherwise.
    prot_centers_list : tuple
        Voxel centers of each box (length N).
    prot_n_list : tuple
        Number of voxels per axis of each box (length N).
    envs : tuple
        Molecule copy returned for each box (length N).

    For an empty ``resids`` the tensor has N = 0 and the three tuples are
    empty.

    Raises
    ------
    IOError
        If the structure file cannot be loaded.
    VoxelizationError
        If voxelization of any box fails.
    """
    start_time_processing = time.time()

    # Load the molecule using MoleculeKit.
    try:
        prot = Molecule(pdb_file)
    except Exception as err:
        raise IOError("could not read pdbfile") from err

    if clean:
        prot.filter("protein and not hydrogen")

    environments = []
    for idx in resids:
        try:
            environments.append((prot.copy(), idx))
        except Exception:
            # Best effort: skip this entry. The f-string also works when idx
            # is not a str (string concatenation raised TypeError for ints).
            print(f"ignoring {idx}")

    results = [voxelize_single_notcentered(env) for env in environments]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    voxels = torch.empty(
        len(results),
        _N_CHANNELS,
        _GRID_POINTS,
        _GRID_POINTS,
        _GRID_POINTS,
        device=device,
    )

    if results:
        boxes, prot_centers_list, prot_n_list, envs = zip(*results)
        for i, box in enumerate(boxes):
            voxels[i] = box
    else:
        # zip(*[]) yields nothing to unpack; return empty containers instead
        # of failing with a ValueError.
        prot_centers_list, prot_n_list, envs = (), (), ()

    print(f"Voxelization took {time.time() - start_time_processing:.3f} seconds ")
    return voxels, prot_centers_list, prot_n_list, envs