File size: 3,773 Bytes
c104a99
92263a6
 
 
 
c104a99
 
92263a6
c104a99
 
92263a6
 
ff512d8
c104a99
 
 
 
 
 
 
 
92263a6
 
 
 
 
c104a99
 
 
 
 
 
 
92263a6
 
 
c104a99
 
 
d8600ba
ff512d8
92263a6
 
 
 
abdd514
92263a6
abdd514
92263a6
 
 
 
 
 
 
 
 
c104a99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy as np
import os.path
import subprocess
import torch

from Bio.PDB import PDBParser
from src import const
from src.visualizer import save_xyz_file
from src.utils import FoundNaNException
from src.datasets import get_one_hot


def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False, offset_idx=0):
    """Sample linkers with ``ddpm`` and save them as XYZ files in ``results``.

    Sampling is attempted up to five times; an attempt is retried whenever
    ``ddpm.sample_chain`` raises ``FoundNaNException``.

    Args:
        ddpm: Model object providing ``sample_chain``, ``n_dims`` and
            ``center_of_mass``.
        data: Batch dictionary. The keys read here are ``positions``,
            ``anchors``, ``fragment_mask`` (without pocket) or
            ``fragment_only_mask`` and ``pocket_mask`` (with pocket).
        sample_fn: Forwarded unchanged to ``ddpm.sample_chain``.
        name: Tag embedded in every output file name.
        with_pocket: When true, the pocket-aware masks are used and the
            entries selected by ``data['pocket_mask']`` are excluded from the
            saved output.
        offset_idx: Offset added to the 1-based sample index in file names.

    Raises:
        RuntimeError: If every sampling attempt raised ``FoundNaNException``.
    """
    max_attempts = 5
    chain = node_mask = None
    last_error = None
    for _ in range(max_attempts):
        try:
            chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
            break
        except FoundNaNException as err:
            # Remember the failure and draw a fresh sample.
            last_error = err
    else:
        # No attempt succeeded: fail with an explicit message instead of
        # letting ``chain[0]`` below crash on ``None``.
        raise RuntimeError(
            f'Linker sampling failed {max_attempts} times in a row because of NaN values'
        ) from last_error

    print('Generated linker')
    # The last axis of the sampled frame holds ``n_dims`` coordinates
    # followed by the remaining per-atom features.
    x = chain[0][:, :, :ddpm.n_dims]
    h = chain[0][:, :, ddpm.n_dims:]

    # Put the molecule back to the initial orientation
    if with_pocket:
        com_mask = data['fragment_only_mask'] if ddpm.center_of_mass == 'fragments' else data['anchors']
    else:
        com_mask = data['fragment_mask'] if ddpm.center_of_mass == 'fragments' else data['anchors']

    # Mean position of the masked atoms, added back to every node that is
    # active in ``node_mask``.
    pos_masked = data['positions'] * com_mask
    N = com_mask.sum(1, keepdims=True)
    mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
    x = x + mean * node_mask

    if with_pocket:
        # Zero the mask at pocket positions (in place) before saving.
        node_mask[torch.where(data['pocket_mask'])] = 0

    batch_size = len(data['positions'])
    names = [f'output_{offset_idx + i + 1}_{name}' for i in range(batch_size)]
    save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
    print('Saved XYZ files')


def try_to_convert_to_sdf(name, num_samples):
    """Convert the saved XYZ outputs to SDF with ``obabel`` on a best-effort basis.

    For each sample index ``1..num_samples`` the file
    ``results/output_{index}_{name}_.xyz`` is passed to ``obabel``. If the
    corresponding ``.sdf`` file exists afterwards its path is reported,
    otherwise the ``.xyz`` path is reported instead.

    Args:
        name: Tag embedded in the file names.
        num_samples: Number of consecutive sample indices to process.

    Returns:
        A list with one path per sample, in index order.
    """
    out_files = []
    for i in range(num_samples):
        out_xyz = f'results/output_{i + 1}_{name}_.xyz'
        out_sdf = f'results/output_{i + 1}_{name}_.sdf'
        try:
            # An argument list (no shell) keeps names containing spaces or
            # shell metacharacters from breaking or hijacking the command.
            subprocess.run(['obabel', out_xyz, '-O', out_sdf])
        except OSError:
            # obabel could not be launched; conversion is optional, so fall
            # back to the XYZ file below.
            pass

        out_files.append(out_sdf if os.path.exists(out_sdf) else out_xyz)

    return out_files


def get_pocket(mol, pdb_path):
    """Extract the atoms of all residues of a PDB structure that contact ``mol``.

    A residue is selected when at least one of its atoms lies within a
    distance of 6 (in the coordinate units of the inputs) of any atom of
    ``mol``. Atoms whose element is not a key of ``const.GEOM_ATOM2IDX`` are
    dropped.

    Args:
        mol: Molecule exposing ``GetConformer().GetPositions()``.
        pdb_path: Path to the PDB file, parsed with ``PDBParser``.

    Returns:
        A tuple ``(pocket_pos, pocket_one_hot, pocket_charges)`` of NumPy
        arrays with one entry per retained atom: the coordinates, the result
        of ``get_one_hot`` and the ``const.GEOM_CHARGES`` value.
    """
    structure = PDBParser().get_structure('', pdb_path)

    # Flatten the structure into parallel per-atom lists.
    residue_ids = []
    atom_coords = []
    for residue in structure.get_residues():
        resid = residue.get_id()[1]
        for atom in residue.get_atoms():
            atom_coords.append(atom.get_coord())
            residue_ids.append(resid)

    residue_ids = np.array(residue_ids)
    atom_coords = np.array(atom_coords)
    mol_atom_coords = mol.GetConformer().GetPositions()

    # Pairwise distances between every structure atom and every ``mol`` atom.
    distances = np.linalg.norm(atom_coords[:, None, :] - mol_atom_coords[None, :, :], axis=-1)
    # NOTE(review): residues are identified only by the second field of
    # ``residue.get_id()``; residues that share that number are selected
    # together -- confirm this is intended for multi-chain files.
    # A set makes the per-residue membership test below O(1).
    contact_residues = set(np.unique(residue_ids[distances.min(1) <= 6]).tolist())

    pocket_pos = []
    pocket_one_hot = []
    pocket_charges = []
    for residue in structure.get_residues():
        if residue.get_id()[1] not in contact_residues:
            continue

        for atom in residue.get_atoms():
            atom_type = atom.element.upper()
            if atom_type not in const.GEOM_ATOM2IDX:
                continue

            pocket_pos.append(atom.get_coord().tolist())
            pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX))
            pocket_charges.append(const.GEOM_CHARGES[atom_type])

    pocket_pos = np.array(pocket_pos)
    pocket_one_hot = np.array(pocket_one_hot)
    pocket_charges = np.array(pocket_charges)

    return pocket_pos, pocket_one_hot, pocket_charges