import os.path
import subprocess

import numpy as np
import torch
from Bio.PDB import PDBParser

from src import const
from src.datasets import get_one_hot
from src.utils import FoundNaNException
from src.visualizer import save_xyz_file

# How many times sampling is retried when the model reports NaNs.
_MAX_SAMPLING_ATTEMPTS = 5

# A residue belongs to the pocket when any of its atoms lies within this
# distance of any ligand atom. The value is compared against coordinates read
# from the PDB file, so it is in the file's units (Angstrom in standard PDB).
_POCKET_CONTACT_DISTANCE = 6


def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False, offset_idx=0):
    """Sample linkers with ``ddpm`` and write one XYZ file per batch element.

    Sampling is retried up to ``_MAX_SAMPLING_ATTEMPTS`` times whenever
    ``ddpm.sample_chain`` raises ``FoundNaNException``. The sampled coordinates
    are shifted back by the center of mass of the reference atoms before being
    written with ``save_xyz_file`` into the ``results`` directory under the
    names ``output_{offset_idx + i + 1}_{name}``.

    Args:
        ddpm: Model exposing ``sample_chain``, ``n_dims`` and
            ``center_of_mass``.
        data: Batch dictionary. Reads ``'positions'``, ``'anchors'``,
            ``'fragment_mask'`` (or ``'fragment_only_mask'`` and
            ``'pocket_mask'`` when ``with_pocket`` is true).
        sample_fn: Forwarded to ``ddpm.sample_chain`` as ``sample_fn``.
        name: Tag embedded in every output file name.
        with_pocket: Whether ``data`` also contains pocket atoms; these are
            masked out of the saved molecules.
        offset_idx: Added to the per-sample index used in the file names.

    Raises:
        RuntimeError: If every sampling attempt failed with
            ``FoundNaNException`` (chained to the last such exception).
    """
    chain = node_mask = None
    last_error = None
    for _ in range(_MAX_SAMPLING_ATTEMPTS):
        try:
            chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
        except FoundNaNException as exc:
            # ``exc`` is unbound once the except block ends; keep a reference
            # so the final error can be chained to it.
            last_error = exc
        else:
            break

    if chain is None:
        # Previously this fell through and crashed with an opaque
        # "'NoneType' object is not subscriptable" on ``chain[0]``.
        raise RuntimeError(
            f'Sampling produced NaNs in all {_MAX_SAMPLING_ATTEMPTS} attempts'
        ) from last_error
    print('Generated linker')

    # The feature axis holds the coordinates first, then the atom features.
    x = chain[0][:, :, :ddpm.n_dims]
    h = chain[0][:, :, ddpm.n_dims:]

    # Move the molecule back to its original position: add the center of mass
    # of the reference atoms (fragments or anchors) to the sampled coordinates.
    if ddpm.center_of_mass == 'fragments':
        com_mask = data['fragment_only_mask'] if with_pocket else data['fragment_mask']
    else:
        com_mask = data['anchors']
    pos_masked = data['positions'] * com_mask
    n_ref_atoms = com_mask.sum(1, keepdim=True)
    mean = torch.sum(pos_masked, dim=1, keepdim=True) / n_ref_atoms
    x = x + mean * node_mask

    if with_pocket:
        # Drop the pocket atoms so that only the molecule itself is saved.
        node_mask[torch.where(data['pocket_mask'])] = 0

    batch_size = len(data['positions'])
    names = [f'output_{offset_idx + i + 1}_{name}' for i in range(batch_size)]
    save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
    print('Saved XYZ files')


def try_to_convert_to_sdf(name, num_samples, offset_idx=0):
    """Convert generated XYZ files to SDF with ``obabel`` on a best-effort basis.

    For every sample the path ``results/output_{offset_idx + i + 1}_{name}_.xyz``
    is passed to ``obabel``. If the corresponding ``.sdf`` file exists
    afterwards it is reported, otherwise the ``.xyz`` path is reported instead.

    Args:
        name: Tag that was embedded in the XYZ file names.
        num_samples: Number of consecutive samples to convert.
        offset_idx: Index offset of the first sample; pass the same value that
            was given to ``generate_linkers``. Defaults to 0.

    Returns:
        A list with one path per sample (SDF when available, XYZ otherwise).
    """
    out_files = []
    for i in range(num_samples):
        stem = f'results/output_{offset_idx + i + 1}_{name}_'
        out_xyz = f'{stem}.xyz'
        out_sdf = f'{stem}.sdf'
        try:
            # An argument list (no shell) keeps ``name`` from being
            # word-split or interpreted by a shell.
            subprocess.run(['obabel', out_xyz, '-O', out_sdf], check=False)
        except OSError:
            # obabel is not installed or not executable: keep the XYZ file,
            # like the shell invocation did when the command was not found.
            pass
        out_files.append(out_sdf if os.path.exists(out_sdf) else out_xyz)
    return out_files


def get_pocket(mol, pdb_path):
    """Extract the atoms of the protein pocket surrounding ``mol``.

    A residue is part of the pocket when at least one of its atoms lies within
    ``_POCKET_CONTACT_DISTANCE`` of any atom of ``mol``. Atoms whose element is
    not a key of ``const.GEOM_ATOM2IDX`` are skipped.

    Args:
        mol: Molecule providing ``GetConformer().GetPositions()``
            (presumably an RDKit ``Mol`` - confirm against the caller).
        pdb_path: Path to the protein PDB file.

    Returns:
        Tuple ``(pocket_pos, pocket_one_hot, pocket_charges)`` of NumPy arrays
        with one entry per kept pocket atom: coordinates, one-hot atom types
        built by ``get_one_hot`` and values from ``const.GEOM_CHARGES``.
    """
    struct = PDBParser().get_structure('', pdb_path)

    # First pass: coordinates of every protein atom and its residue number.
    residue_ids = []
    atom_coords = []
    for residue in struct.get_residues():
        resid = residue.get_id()[1]
        for atom in residue.get_atoms():
            atom_coords.append(atom.get_coord())
            residue_ids.append(resid)
    residue_ids = np.array(residue_ids)
    atom_coords = np.array(atom_coords)

    # Pairwise protein-atom / ligand-atom distances.
    mol_atom_coords = mol.GetConformer().GetPositions()
    distances = np.linalg.norm(atom_coords[:, None, :] - mol_atom_coords[None, :, :], axis=-1)
    in_contact = distances.min(1) <= _POCKET_CONTACT_DISTANCE
    # NOTE(review): residues are matched by sequence number only, so equally
    # numbered residues of other chains/models are selected too - confirm that
    # this is intended for multi-chain inputs.
    contact_residues = set(np.unique(residue_ids[in_contact]).tolist())

    # Second pass: keep the supported atoms of every contact residue.
    pocket_pos = []
    pocket_one_hot = []
    pocket_charges = []
    for residue in struct.get_residues():
        if residue.get_id()[1] not in contact_residues:
            continue
        for atom in residue.get_atoms():
            atom_type = atom.element.upper()
            if atom_type not in const.GEOM_ATOM2IDX:
                continue
            pocket_pos.append(atom.get_coord().tolist())
            pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX))
            pocket_charges.append(const.GEOM_CHARGES[atom_type])

    pocket_pos = np.array(pocket_pos)
    pocket_one_hot = np.array(pocket_one_hot)
    pocket_charges = np.array(pocket_charges)
    return pocket_pos, pocket_one_hot, pocket_charges