Spaces:
Running
on
A10G
Running
on
A10G
import torch | |
import numpy as np | |
from rdkit import Chem, Geometry | |
from src import const | |
def create_conformer(coords):
    """Build an RDKit Conformer holding one 3D point per row of ``coords``.

    Args:
        coords: iterable of (x, y, z) triples; row i becomes the position of
            atom i.

    Returns:
        A ``Chem.Conformer`` with those positions set.
    """
    points = [Geometry.Point3D(x, y, z) for x, y, z in coords]
    conformer = Chem.Conformer()
    for atom_idx, point in enumerate(points):
        conformer.SetAtomPosition(atom_idx, point)
    return conformer
def build_molecules(one_hot, x, node_mask, is_geom, margins=const.MARGINS_EDM):
    """Convert a padded batch of samples into one RDKit molecule per sample.

    Args:
        one_hot: per-sample atom-type scores; ``one_hot[i][mask]`` is reduced
            with ``argmax(dim=1)`` to atom-type indices
            (presumably (B, N, num_types) -- confirm against callers).
        x: per-sample atom coordinates, masked the same way as ``one_hot``
            (presumably (B, N, 3) -- confirm).
        node_mask: per-sample node mask; entries equal to 1 mark real atoms
            (presumably (B, N, 1) -- confirm).
        is_geom: select ``const.GEOM_IDX2ATOM`` instead of ``const.IDX2ATOM``.
        margins: bond-length margins forwarded to ``build_molecule``.

    Returns:
        List with the result of ``build_molecule`` for each sample, in batch
        order.
    """
    molecules = []
    for i, sample_one_hot in enumerate(one_hot):
        # Flatten instead of squeeze(): squeeze() turns a (1, 1) mask into a
        # 0-d tensor, and boolean indexing with a 0-d mask adds a dimension
        # rather than selecting atoms, breaking argmax(dim=1) below.
        mask = node_mask[i].reshape(-1) == 1
        atom_types = sample_one_hot[mask].argmax(dim=1).detach().cpu()
        positions = x[i][mask].detach().cpu()
        molecules.append(build_molecule(positions, atom_types, is_geom, margins=margins))
    return molecules
def build_molecule(positions, atom_types, is_geom, margins=const.MARGINS_EDM):
    """Assemble an RDKit molecule from atom coordinates and atom-type indices.

    Bonds are inferred by ``build_xae_molecule`` from pairwise distances, and
    the coordinates are attached to the molecule as a conformer.

    Args:
        positions: N x 3 atom coordinates (tensor; detached and moved to CPU).
        atom_types: N atom-type indices.
        is_geom: name atoms with ``const.GEOM_IDX2ATOM`` instead of
            ``const.IDX2ATOM``.
        margins: bond-length margins forwarded to bond inference.

    Returns:
        A ``Chem.RWMol`` with one atom per type index, one bond per nonzero
        adjacency entry, and a single conformer built from ``positions``.
    """
    atom_vocab = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM
    atoms, adjacency, bond_types = build_xae_molecule(
        positions, atom_types, is_geom=is_geom, margins=margins
    )

    mol = Chem.RWMol()
    for type_idx in atoms:
        mol.AddAtom(Chem.Atom(atom_vocab[type_idx.item()]))

    # build_xae_molecule fills adjacency only below the diagonal, so each
    # bond is visited (and added) exactly once here.
    for begin, end in torch.nonzero(adjacency).tolist():
        order = bond_types[begin, end].item()
        mol.AddBond(begin, end, const.BOND_DICT[order])

    coords = positions.detach().cpu().numpy().astype(np.float64)
    mol.AddConformer(create_conformer(coords))
    return mol
def build_xae_molecule(positions, atom_types, is_geom, margins=const.MARGINS_EDM):
    """Infer bonds from geometry and return the triplet (X, A, E).

    Every atom pair is classified by ``get_bond_order`` using the Euclidean
    distance between the two atoms. Only entries below the diagonal
    (row > col) are filled, so each bond is stored exactly once and A is
    deliberately NOT symmetric (the graph is directed).

    Args:
        positions: N x 3 coordinates (already masked to keep the final nodes).
        atom_types: N atom-type indices.
        is_geom: use ``const.GEOM_IDX2ATOM`` instead of ``const.IDX2ATOM``.
        margins: per-bond-order margins forwarded to ``get_bond_order``.

    Returns:
        X: N (int) -- ``atom_types``, returned unchanged.
        A: N x N (bool) -- binary adjacency matrix.
        E: N x N (int) -- bond order (0 if no bond), such that A == E.bool().
    """
    num_atoms = positions.shape[0]
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM
    adjacency = torch.zeros((num_atoms, num_atoms), dtype=torch.bool)
    edge_types = torch.zeros((num_atoms, num_atoms), dtype=torch.int)

    # cdist expects a batch dimension; add it and strip it again.
    coords = positions.unsqueeze(0)
    dists = torch.cdist(coords, coords, p=2).squeeze(0)

    # TODO: a batched version of get_bond_order to avoid the for loop
    for row in range(num_atoms):
        for col in range(row):
            # Order the pair by type index before looking up element symbols.
            lo, hi = sorted([atom_types[row], atom_types[col]])
            bond_order = get_bond_order(
                idx2atom[lo.item()], idx2atom[hi.item()], dists[row, col], margins=margins
            )
            if bond_order > 0:
                adjacency[row, col] = 1
                edge_types[row, col] = bond_order

    return atom_types, adjacency, edge_types
def get_bond_order(atom1, atom2, distance, check_exists=True, margins=const.MARGINS_EDM):
    """Classify the bond between two elements from their interatomic distance.

    The scaled distance is compared against reference lengths in
    ``const.BONDS_1`` / ``BONDS_2`` / ``BONDS_3``, each widened by
    ``margins[0]`` / ``margins[1]`` / ``margins[2]`` respectively.

    Args:
        atom1, atom2: element keys used to index the bond-length tables.
        distance: interatomic distance; multiplied by 100 before comparison.
        check_exists: if True, return 0 for pairs absent from
            ``const.BONDS_1`` instead of looking them up (with False, a
            missing pair fails on the table lookup).
        margins: indexable of three additive margins (single, double, triple).

    Returns:
        0 (no bond), 1 (single), 2 (double) or 3 (triple).
    """
    # Change the metric to the tables' units (presumably Angstrom -> pm;
    # NOTE(review): confirm against const.BONDS_*).
    distance = 100 * distance

    single_table = const.BONDS_1
    # Large molecules may contain element pairs with no typical bond length.
    if check_exists and (atom1 not in single_table or atom2 not in single_table[atom1]):
        return 0

    # margin1, margin2 and margin3 have been tuned to maximize the stability
    # of the QM9 true samples.
    if not (distance < single_table[atom1][atom2] + margins[0]):
        return 0  # No bond

    double_table = const.BONDS_2
    has_double = atom1 in double_table and atom2 in double_table[atom1]
    if not has_double or not (distance < double_table[atom1][atom2] + margins[1]):
        return 1  # Single

    triple_table = const.BONDS_3
    if atom1 in triple_table and atom2 in triple_table[atom1]:
        if distance < triple_table[atom1][atom2] + margins[2]:
            return 3  # Triple
    return 2  # Double