DiffLinker / src /molecule_builder.py
igashov
DiffLinker code
95ba5bc
raw
history blame contribute delete
No virus
3.63 kB
import torch
import numpy as np
from rdkit import Chem, Geometry
from src import const
def create_conformer(coords):
    """Build an RDKit conformer from a sequence of 3D coordinates.

    args:
        coords: iterable of (x, y, z) triples; the position of each triple
            in the iteration is used as the atom index
    returns:
        Chem.Conformer with one atom position set per triple
    """
    conf = Chem.Conformer()
    for atom_idx, xyz in enumerate(coords):
        x, y, z = xyz
        point = Geometry.Point3D(x, y, z)
        conf.SetAtomPosition(atom_idx, point)
    return conf
def build_molecules(one_hot, x, node_mask, is_geom, margins=const.MARGINS_EDM):
    """Build one molecule per batch element.

    args:
        one_hot: batch of per-node atom-type scores; the atom type of a node
            is the argmax over dim 1 of the masked sample
            (presumably B x N x num_types -- confirm against caller)
        x: batch of node coordinates, indexed like one_hot
            (presumably B x N x 3 -- confirm against caller)
        node_mask: batch of node masks; entries equal to 1 mark nodes to keep.
            Assumed to hold exactly one entry per node (e.g. N or N x 1).
        is_geom: forwarded to build_molecule to select the atom vocabulary
        margins: bond-length margins forwarded to build_molecule
    returns:
        list with one molecule (as returned by build_molecule) per sample
    """
    molecules = []
    for i in range(len(one_hot)):
        # Flatten rather than squeeze(): squeeze() turns a single-node mask
        # into a 0-d tensor, and indexing with a 0-d boolean inserts an extra
        # axis, which would corrupt argmax(dim=1) and the positions below.
        mask = node_mask[i].reshape(-1) == 1
        atom_types = one_hot[i][mask].argmax(dim=1).detach().cpu()
        positions = x[i][mask].detach().cpu()
        mol = build_molecule(positions, atom_types, is_geom, margins=margins)
        molecules.append(mol)
    return molecules
def build_molecule(positions, atom_types, is_geom, margins=const.MARGINS_EDM):
    """Assemble an RDKit molecule (atoms, bonds and one conformer).

    args:
        positions: N x 3 tensor of atom coordinates
        atom_types: N tensor of atom-type indices
        is_geom: if true, atom indices are decoded with const.GEOM_IDX2ATOM,
            otherwise with const.IDX2ATOM
        margins: bond-length margins forwarded to build_xae_molecule
    returns:
        Chem.RWMol with one atom per entry of atom_types, one bond per
        non-zero adjacency entry, and a conformer built from positions
    """
    symbol_of = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM
    atom_ids, adjacency, bond_orders = build_xae_molecule(
        positions, atom_types, is_geom=is_geom, margins=margins
    )

    rw_mol = Chem.RWMol()
    for atom_id in atom_ids:
        rw_mol.AddAtom(Chem.Atom(symbol_of[atom_id.item()]))

    # Each row of nonzero() is an (i, j) index pair of a bonded atom pair.
    for begin, end in torch.nonzero(adjacency).tolist():
        bond_type = const.BOND_DICT[bond_orders[begin, end].item()]
        rw_mol.AddBond(begin, end, bond_type)

    coords = positions.detach().cpu().numpy().astype(np.float64)
    rw_mol.AddConformer(create_conformer(coords))
    return rw_mol
def build_xae_molecule(positions, atom_types, is_geom, margins=const.MARGINS_EDM):
    """ Derive bonds from pairwise distances and return the triplet (X, A, E).
        args:
        positions: N x 3  (already masked to keep final number nodes)
        atom_types: N
        is_geom: selects const.GEOM_IDX2ATOM instead of const.IDX2ATOM
        margins: bond-length margins forwarded to get_bond_order
        returns:
        X: N         (int), the atom_types argument itself
        A: N x N     (bool)  (binary adjacency matrix)
        E: N x N     (int)   (bond type, 0 if no bond) such that A = E.bool()
        Only entries [i, j] with j < i are filled in.
    """
    num_atoms = positions.shape[0]
    adjacency = torch.zeros((num_atoms, num_atoms), dtype=torch.bool)
    bond_types = torch.zeros((num_atoms, num_atoms), dtype=torch.int)
    symbol_of = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    # cdist works on batches, so add a batch axis and drop it afterwards.
    batched = positions.unsqueeze(0)
    pairwise = torch.cdist(batched, batched, p=2).squeeze(0)

    # TODO: a batched version of get_bond_order to avoid the for loop
    for i in range(num_atoms):
        for j in range(i):
            # The lookup is keyed on the pair ordered by atom-type index.
            low, high = sorted([atom_types[i], atom_types[j]])
            order = get_bond_order(
                symbol_of[low.item()],
                symbol_of[high.item()],
                pairwise[i, j],
                margins=margins,
            )
            if order <= 0:
                continue
            # Warning: the graph should be DIRECTED
            adjacency[i, j] = True
            bond_types[i, j] = order
    return atom_types, adjacency, bond_types
def get_bond_order(atom1, atom2, distance, check_exists=True, margins=const.MARGINS_EDM):
    """Classify the bond between two atoms from their distance.

    args:
        atom1, atom2: atom keys used to index const.BONDS_1 / BONDS_2 / BONDS_3
        distance: inter-atomic distance; multiplied by 100 before comparison
            with the tabulated bond lengths
        check_exists: if true, return 0 when the pair has no entry in
            const.BONDS_1 instead of failing on the lookup
        margins: three offsets added to the single, double and triple bond
            lengths respectively
    returns:
        0 (no bond), 1 (single), 2 (double) or 3 (triple)
    """
    scaled = 100 * distance  # We change the metric

    # Check exists for large molecules where some atom pairs do not have a
    # typical bond length.
    if check_exists:
        if atom1 not in const.BONDS_1:
            return 0
        if atom2 not in const.BONDS_1[atom1]:
            return 0

    # margin1, margin2 and margin3 have been tuned to maximize the stability
    # of the QM9 true samples
    single_limit = const.BONDS_1[atom1][atom2] + margins[0]
    if not (scaled < single_limit):
        return 0  # No bond

    # Without a tabulated double-bond length the pair can only be single.
    has_double = atom1 in const.BONDS_2 and atom2 in const.BONDS_2[atom1]
    if not has_double:
        return 1  # Single
    double_limit = const.BONDS_2[atom1][atom2] + margins[1]
    if not (scaled < double_limit):
        return 1  # Single

    has_triple = atom1 in const.BONDS_3 and atom2 in const.BONDS_3[atom1]
    if has_triple:
        triple_limit = const.BONDS_3[atom1][atom2] + margins[2]
        if scaled < triple_limit:
            return 3  # Triple
    return 2  # Double