DiffLinker / src /generation.py
igashov's picture
Pocket-conditioned generation
c104a99
raw
history blame
No virus
3.71 kB
import numpy as np
import os.path
import subprocess
import torch
from Bio.PDB import PDBParser
from src import const
from src.visualizer import save_xyz_file
from src.utils import FoundNaNException
from src.datasets import get_one_hot
# Number of samples handled per request: generate_linkers builds one output
# name per sample (output_1 .. output_N_SAMPLES) and try_to_convert_to_sdf
# looks for exactly that many files.
# NOTE(review): presumably this must equal the batch size of `data` passed to
# generate_linkers — confirm at call sites.
N_SAMPLES = 5


def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False, max_attempts=5):
    """Sample linkers with ``ddpm`` and save them as XYZ files under ``results``.

    Sampling is retried when ``ddpm.sample_chain`` raises
    ``FoundNaNException``. The sampled coordinates are then shifted back by the
    mean position of the atoms selected by the centre-of-mass mask. The result
    is passed to ``save_xyz_file`` with the names ``output_{i}_{name}`` for
    i = 1..N_SAMPLES.

    Args:
        ddpm: Model providing ``sample_chain``, ``n_dims`` and ``center_of_mass``.
        data: Batch dict. It is forwarded to ``sample_chain``. This function
            also reads ``'positions'``. For the centre-of-mass mask it reads
            ``'fragment_mask'`` (or ``'fragment_only_mask'`` when
            ``with_pocket``) if ``ddpm.center_of_mass == 'fragments'``, and
            ``'anchors'`` otherwise. When ``with_pocket`` it also reads
            ``'pocket_mask'``.
        sample_fn: Forwarded unchanged to ``ddpm.sample_chain``.
        name: Suffix used in the output names.
        with_pocket: If true, pocket atoms get their ``node_mask`` entries set
            to 0 before saving.
        max_attempts: Maximum number of sampling attempts (default 5).

    Raises:
        ValueError: If ``max_attempts`` is smaller than 1.
        FoundNaNException: If every sampling attempt raised it. Previously
            this surfaced as a ``TypeError`` from indexing ``None``.
    """
    if max_attempts < 1:
        raise ValueError('max_attempts must be at least 1')

    last_error = None
    for _ in range(max_attempts):
        try:
            chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
            break
        except FoundNaNException as exc:
            # Keep a reference: the `exc` name is unbound after the except block.
            last_error = exc
    else:
        # Every attempt produced NaNs; fail with the real cause instead of
        # crashing later on `chain[0]` with chain still unset.
        raise last_error
    print('Generated linker')

    # keep_frames=1: chain[0] is presumably the final sampled frame — confirm
    # against sample_chain. Its last axis holds n_dims coordinates, then features.
    frame = chain[0]
    x = frame[:, :, :ddpm.n_dims]
    h = frame[:, :, ddpm.n_dims:]

    # Put the molecule back to the initial position: add back the mean position
    # of the atoms selected by com_mask (the same set used to centre the input).
    if ddpm.center_of_mass == 'fragments':
        com_mask = data['fragment_only_mask'] if with_pocket else data['fragment_mask']
    else:
        com_mask = data['anchors']
    pos_masked = data['positions'] * com_mask
    n_selected = com_mask.sum(1, keepdim=True)
    mean = torch.sum(pos_masked, dim=1, keepdim=True) / n_selected
    x = x + mean * node_mask

    if with_pocket:
        # Zero the pocket atoms in node_mask so save_xyz_file (presumably)
        # leaves them out of the written molecules.
        node_mask[torch.where(data['pocket_mask'])] = 0

    names = [f'output_{i + 1}_{name}' for i in range(N_SAMPLES)]
    save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
    print('Saved XYZ files')
def try_to_convert_to_sdf(name, n_samples=None, out_dir='results'):
    """Convert generated XYZ files to SDF with Open Babel, falling back to XYZ.

    For each sample i in 1..n_samples, runs ``obabel`` on
    ``{out_dir}/output_{i}_{name}_.xyz`` to produce
    ``{out_dir}/output_{i}_{name}_.sdf``.

    Args:
        name: Name suffix used when the XYZ files were written.
        n_samples: Number of samples to convert; defaults to ``N_SAMPLES``.
        out_dir: Directory containing the files; defaults to ``'results'``.

    Returns:
        A list with one path per sample, in order. Each entry is the SDF path
        if a non-empty SDF exists after the conversion attempt, otherwise the
        XYZ path. A missing ``obabel`` executable is tolerated; in that case
        the XYZ path is returned unless a non-empty SDF is already present.
    """
    if n_samples is None:
        n_samples = N_SAMPLES
    out_files = []
    for i in range(1, n_samples + 1):
        out_xyz = f'{out_dir}/output_{i}_{name}_.xyz'
        out_sdf = f'{out_dir}/output_{i}_{name}_.sdf'
        try:
            # Pass an argument list without a shell. A shell string would let
            # characters in `name` (e.g. `;`, spaces) inject commands or split
            # arguments. NOTE(review): `name` may be derived from user input
            # (uploaded file names) — confirm at call sites.
            subprocess.run(['obabel', out_xyz, '-O', out_sdf], check=False)
        except OSError:
            # obabel is missing or not executable: best effort, keep the XYZ fallback.
            pass
        # An empty SDF means nothing was converted, so it is not a usable result.
        if os.path.isfile(out_sdf) and os.path.getsize(out_sdf) > 0:
            out_files.append(out_sdf)
        else:
            out_files.append(out_xyz)
    return out_files
def get_pocket(mol, pdb_path, distance_cutoff=6):
    """Extract the pocket atoms of a PDB structure that surround ``mol``.

    A residue belongs to the pocket if any of its atoms lies within
    ``distance_cutoff`` (in the coordinate units of the inputs) of any atom of
    ``mol``'s conformer. All atoms of pocket residues are kept, except those
    whose upper-cased element is not a key of ``const.GEOM_ATOM2IDX``.

    Residues are identified by their position in the structure traversal, not
    by sequence number alone. The old number-only match also pulled in
    same-numbered residues from other chains, or waters / insertion-code
    residues, that were not in contact.

    Args:
        mol: Molecule exposing ``GetConformer().GetPositions()``; presumably
            an RDKit ``Mol`` — confirm.
        pdb_path: Path of the PDB file to parse with ``PDBParser``.
        distance_cutoff: Contact distance threshold (default 6).

    Returns:
        ``(pocket_pos, pocket_one_hot, pocket_charges)`` as numpy arrays with
        one entry per kept atom, in structure order:
        - ``pocket_pos``: float64 coordinates of shape ``(n, 3)``.
        - ``pocket_one_hot``: ``get_one_hot(element, const.GEOM_ATOM2IDX)`` per atom.
        - ``pocket_charges``: ``const.GEOM_CHARGES[element]`` per atom.
    """
    struct = PDBParser().get_structure('', pdb_path)

    # One pass over the structure: remember each residue's atoms, and build a
    # flat coordinate array whose rows are tagged with the owning residue's index.
    residue_atoms = []
    atom_coords = []
    atom_residue_idx = []
    for residue_idx, residue in enumerate(struct.get_residues()):
        atoms = list(residue.get_atoms())
        residue_atoms.append(atoms)
        for atom in atoms:
            atom_coords.append(atom.get_coord())
            atom_residue_idx.append(residue_idx)
    # reshape keeps the array 2-D even for an empty structure.
    atom_coords = np.array(atom_coords).reshape(-1, 3)
    atom_residue_idx = np.array(atom_residue_idx, dtype=int)

    mol_atom_coords = mol.GetConformer().GetPositions()
    distances = np.linalg.norm(atom_coords[:, None, :] - mol_atom_coords[None, :, :], axis=-1)
    is_contact_residue = np.zeros(len(residue_atoms), dtype=bool)
    is_contact_residue[atom_residue_idx[distances.min(1) <= distance_cutoff]] = True

    pocket_pos = []
    pocket_one_hot = []
    pocket_charges = []
    for residue_idx, atoms in enumerate(residue_atoms):
        if not is_contact_residue[residue_idx]:
            continue
        for atom in atoms:
            atom_type = atom.element.upper()
            # Element types the model has no encoding for are dropped.
            if atom_type not in const.GEOM_ATOM2IDX:
                continue
            # .tolist() yields Python floats, so pocket_pos ends up float64.
            pocket_pos.append(atom.get_coord().tolist())
            pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX))
            pocket_charges.append(const.GEOM_CHARGES[atom_type])

    pocket_pos = np.array(pocket_pos, dtype=float).reshape(-1, 3)
    pocket_one_hot = np.array(pocket_one_hot)
    pocket_charges = np.array(pocket_charges)
    return pocket_pos, pocket_one_hot, pocket_charges