# metal3d/utils/voxelization.py
# Hosting-page header captured with the file:
#   simonduerr -- commit 61855d2, "Update utils/voxelization.py"
import time
import multiprocessing
from multiprocessing import Pool
import torch
import numpy as np
from moleculekit.molecule import Molecule
from moleculekit.tools.voxeldescriptors import getVoxelDescriptors
from moleculekit.tools.atomtyper import prepareProteinForAtomtyping
from moleculekit.tools.preparation import systemPrepare
class AtomtypingError(Exception):
    """Error type named for atom typing failures.

    No raise site for it is visible in this part of the file.
    """
class StructureCleaningError(Exception):
    """Error type named for structure cleaning failures.

    No raise site for it is visible in this part of the file.
    """
class ProteinPrepareError(Exception):
    """Error type named for protein preparation failures.

    No raise site for it is visible in this part of the file.
    """
class VoxelizationError(Exception):
    """Raised by ``voxelize_single_notcentered`` when building the voxel grid fails."""
# Metal symbols in two spellings: all upper-case first, then capitalised
# ("MG" ... "LI", then "Mg" ... "Li").  dict.fromkeys keeps first-seen order
# and drops the one duplicate ("K" is identical in both spellings).
metal_atypes = tuple(
    dict.fromkeys(
        spell(symbol)
        for spell in (str.upper, str.capitalize)
        for symbol in (
            "MG", "ZN", "MN", "CA", "FE", "HG",
            "CD", "NI", "CO", "CU", "K", "LI",
        )
    )
)
def _selection_column(prot, selection):
    """Return the atom mask for ``selection`` reshaped to a single column."""
    mask = prot.atomselect(selection)
    return mask.reshape(mask.shape[0], 1)


def voxelize_single_notcentered(env):
    """Voxelize one box of a structure around a single CA atom.

    Eight user-defined channels are built from atom selections and passed to
    moleculekit's ``getVoxelDescriptors`` as ``userchannels``. Channel order:
    hydrophobic, aromatic, metal coordination, H-bond acceptor, H-bond donor,
    positive, negative, occupancy.

    Parameters
    ----------
    env : tuple
        Tuple of the form (prot, idx). ``idx`` is inserted into the selection
        ``"index {idx} and name CA"`` to obtain the box center.

    Returns
    -------
    voxels : torch.Tensor
        Voxel grid of shape (1, nchannels, prot_N[0], prot_N[1], prot_N[2]).
        The request is boxsize [16, 16, 16] with voxelsize 0.5; the caller in
        this file stores the result in an (8, 32, 32, 32) slot.
    prot_centers
        Voxel centers exactly as returned by ``getVoxelDescriptors``.
    prot_N
        Per-axis voxel counts as returned by ``getVoxelDescriptors``.
    prot : moleculekit.Molecule
        A copy of the input molecule.

    Raises
    ------
    VoxelizationError
        If any atom selection or the call to ``getVoxelDescriptors`` fails.
        The underlying exception is chained as ``__cause__``.
    """
    prot, res_index = env
    center = prot.get("coords", sel=f"index {res_index} and name CA")
    box_size = [16, 16, 16]  # size of box

    # One selection per channel; the order here fixes the channel order.
    channel_selections = (
        # hydrophobic
        "element C",
        # aromatic
        "resname HIS HIE HIP HID TRP TYR PHE and sidechain and not name CB and not hydrogen",
        # metal coordination
        "(name ND1 NE2 SG OE1 OE2 OD2) or (protein and name O N)",
        # hydrogen-bond acceptor
        "(resname ASP GLU HIS HIE HIP HID SER THR MSE CYS MET and name ND2 NE2 OE1 OE2 OD1 OD2 OG OG1 SE SG) or name O",
        # hydrogen-bond donor
        "(resname ASN GLN ASH GLH TRP MSE SER THR MET CYS and name ND2 NE2 NE1 SG SE OG OG1) or name N",
        # positive
        "resname LYS ARG HIS HIE HIP HID and name NZ NH1 NH2 ND1 NE2 NE",
        # negative
        "(resname ASP GLU ASH GLH and name OD1 OD2 OE1 OE2)",
        # occupancy
        "protein and not hydrogen",
    )

    try:
        # Each mask is reshaped with its own length (the acceptor/donor masks
        # previously borrowed the metal-coordination mask's shape).
        userchannels = np.hstack(
            [_selection_column(prot, selection) for selection in channel_selections]
        )
        prot_vox, prot_centers, prot_N = getVoxelDescriptors(
            prot,
            center=center,
            userchannels=userchannels,
            boxsize=box_size,
            voxelsize=0.5,
            validitychecks=False,
        )
    except Exception as err:
        # Keep the root cause attached; do not intercept KeyboardInterrupt
        # or SystemExit.
        raise VoxelizationError(f"voxelization of {res_index} failed") from err

    # prot_vox has one column per channel (its second dimension is read as the
    # channel count); transpose so channels come first, then unfold the
    # flattened voxel axis into the three grid axes.
    nchannels = prot_vox.shape[1]
    prot_vox_t = (
        prot_vox.transpose()
        .reshape([1, nchannels, prot_N[0], prot_N[1], prot_N[2]])
        .copy()
    )
    voxels = torch.from_numpy(prot_vox_t)
    return (voxels, prot_centers, prot_N, prot.copy())
def processStructures(pdb_file, resids, clean=True):
    """Load a structure and voxelize one box per requested center.

    Parameters
    ----------
    pdb_file : str
        Path to the pdb file, passed to ``moleculekit.molecule.Molecule``.
    resids : list
        Identifiers of the box centers. Each one is forwarded to
        ``voxelize_single_notcentered``, which uses it in an
        ``"index {idx} and name CA"`` selection.
    clean : bool
        If True, keep only atoms matching ``"protein and not hydrogen"``.

    Returns
    -------
    voxels : torch.Tensor
        Voxelized boxes with 8 channels, shape (N, 8, 32, 32, 32), on the CPU.
    prot_centers_list : tuple
        Per-box voxel centers as returned by ``voxelize_single_notcentered``.
    prot_n_list : tuple
        Per-box voxel counts as returned by ``voxelize_single_notcentered``.
    envs : tuple
        Per-box molecule copies as returned by ``voxelize_single_notcentered``.

    With no boxes to voxelize, ``voxels`` has shape (0, 8, 32, 32, 32) and the
    three tuples are empty.

    Raises
    ------
    IOError
        If the structure file cannot be loaded.
    VoxelizationError
        Propagated from ``voxelize_single_notcentered``.
    """
    start_time_processing = time.time()

    # load molecule using MoleculeKit
    try:
        # Load the file the caller asked for (a fixed sample path was
        # hard-coded here before, so ``pdb_file`` was ignored).
        prot = Molecule(pdb_file)
    except Exception as err:
        raise IOError("could not read pdbfile") from err

    if clean:
        prot.filter("protein and not hydrogen")

    # Every box gets its own copy of the molecule.
    environments = []
    for idx in resids:
        try:
            environments.append((prot.copy(), idx))
        except Exception:
            # Best effort: skip this center. The f-string also works when
            # idx is not a str.
            print(f"ignoring {idx}")

    results = [voxelize_single_notcentered(env) for env in environments]
    voxels = torch.empty(len(results), 8, 32, 32, 32, device="cpu")

    if results:
        box_tensors, prot_centers_list, prot_n_list, envs = zip(*results)
        for i, box in enumerate(box_tensors):
            # box has a leading singleton dimension, which the assignment
            # broadcasts away.
            voxels[i] = box
    else:
        # zip(*[]) yields nothing to unpack, so handle "no boxes" explicitly.
        prot_centers_list, prot_n_list, envs = (), (), ()

    print(f"Voxelization took {time.time() - start_time_processing:.3f} seconds ")
    return voxels, prot_centers_list, prot_n_list, envs