|
import copy |
|
import torch |
|
from io import BytesIO |
|
from openbabel import openbabel |
|
from torch_geometric.utils import to_networkx |
|
from torch_geometric.data import Data |
|
from torch_scatter import scatter |
|
from rdkit import Chem |
|
from rdkit.Chem.rdchem import Mol, HybridizationType, BondType |
|
from rdkit.Chem.rdchem import BondType as BT |
|
|
|
|
|
# Lookup tables between RDKit bond types and integer edge labels.
# BOND_TYPES: BondType enum member -> its position in ``BT.names``.
BOND_TYPES = {
    bond_type: index for index, bond_type in enumerate(BT.names.values())
}
# BOND_NAMES: the same position -> the bond type's name (key of ``BT.names``).
BOND_NAMES = dict(enumerate(BT.names.keys()))
|
|
|
|
|
|
|
def rdmol_to_data(mol, smiles=None):
    """Convert an RDKit molecule with one conformer into a graph ``Data`` object.

    Args:
        mol: RDKit molecule; it must carry exactly one conformer.
        smiles: Optional SMILES string to store on the result. When ``None``
            it is generated from ``mol`` after removing hydrogens.

    Returns:
        A ``Data`` object with the fields:
            atom_type: atomic number of every atom (``torch.long``).
            pos: coordinates of conformer 0 (``torch.float32``).
            edge_index: both directions of every bond, ordered by
                (source index, target index).
            edge_type: ``BOND_TYPES`` label of every directed edge
                (``torch.long``), aligned with ``edge_index``.
            rdmol: a deep copy of ``mol``.
            smiles: the SMILES string described above.
            nx: undirected networkx graph built from the data object.

    Raises:
        AssertionError: if ``mol`` does not have exactly one conformer.
    """
    assert mol.GetNumConformers() == 1, "expected exactly one conformer"
    num_atoms = mol.GetNumAtoms()

    pos = torch.tensor(mol.GetConformer(0).GetPositions(), dtype=torch.float32)
    z = torch.tensor(
        [atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=torch.long
    )

    # Every bond contributes two directed edges that share one bond label.
    row, col, edge_type = [], [], []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        row += [start, end]
        col += [end, start]
        edge_type += 2 * [BOND_TYPES[bond.GetBondType()]]

    edge_index = torch.tensor([row, col], dtype=torch.long)
    # Explicit dtype: without it a bond-free molecule would produce an empty
    # tensor of torch's default float dtype instead of an integer one.
    edge_type = torch.tensor(edge_type, dtype=torch.long)

    # Order edges by (source, target) so the layout does not depend on the
    # order in which RDKit enumerates bonds.
    perm = (edge_index[0] * num_atoms + edge_index[1]).argsort()
    edge_index = edge_index[:, perm]
    edge_type = edge_type[perm]

    if smiles is None:
        smiles = Chem.MolToSmiles(Chem.RemoveHs(mol))

    data = Data(atom_type=z, pos=pos, edge_index=edge_index, edge_type=edge_type,
                rdmol=copy.deepcopy(mol), smiles=smiles)
    data.nx = to_networkx(data, to_undirected=True)

    return data
|
|
|
|
|
def generated_to_xyz(data):
    """Serialise the generated ligand atoms as an XYZ-format string.

    Args:
        data: Object exposing the tensors ``ligand_context_element`` (one
            atomic number per atom) and ``ligand_context_pos`` (one x, y, z
            triple per atom).

    Returns:
        str: the atom count, an empty comment line, then one
        ``"<symbol> <x> <y> <z>"`` line per atom with 8 decimal places.
    """
    ptable = Chem.GetPeriodicTable()

    elements = data.ligand_context_element
    num_atoms = elements.size(0)
    # Move all coordinates to host memory once instead of once per atom.
    positions = data.ligand_context_pos.cpu().tolist()

    # Collect the pieces and join once rather than growing a string in a loop.
    lines = ["%d\n\n" % (num_atoms, )]
    for i in range(num_atoms):
        symb = ptable.GetElementSymbol(elements[i].item())
        x, y, z = positions[i]
        lines.append("%s %.8f %.8f %.8f\n" % (symb, x, y, z))

    return "".join(lines)
|
|
|
|
|
def generated_to_sdf(data):
    """Convert the generated ligand atoms to an SDF string using Open Babel.

    The atoms are first written as XYZ text via ``generated_to_xyz`` and then
    passed through an Open Babel ``xyz`` -> ``sdf`` conversion.

    Args:
        data: Object accepted by ``generated_to_xyz``.

    Returns:
        The string produced by Open Babel's ``WriteString``.

    Note:
        The return value of ``ReadString`` is not inspected, so a failed read
        is not reported here.
    """
    xyz_text = generated_to_xyz(data)

    converter = openbabel.OBConversion()
    converter.SetInAndOutFormats("xyz", "sdf")

    ob_mol = openbabel.OBMol()
    converter.ReadString(ob_mol, xyz_text)
    return converter.WriteString(ob_mol)
|
|
|
|
|
def sdf_to_rdmol(sdf):
    """Read the first record of an SDF string with RDKit.

    Args:
        sdf: SDF contents as a ``str``.

    Returns:
        The first item yielded by ``Chem.ForwardSDMolSupplier`` for the given
        text, or ``None`` when the supplier yields nothing.
    """
    buffer = BytesIO(sdf.encode())
    supplier = Chem.ForwardSDMolSupplier(buffer)
    # Only the first record is of interest; later records are never read.
    return next(iter(supplier), None)
|
|
|
def generated_to_rdmol(data):
    """Build an RDKit molecule from generated ligand atoms.

    Chains ``generated_to_sdf`` and ``sdf_to_rdmol``.

    Args:
        data: Object accepted by ``generated_to_sdf``.

    Returns:
        Whatever ``sdf_to_rdmol`` returns for the generated SDF text.
    """
    return sdf_to_rdmol(generated_to_sdf(data))
|
|
|
|
|
def filter_rd_mol(rdmol):
    """Reject molecules in which two three-membered rings share an atom.

    Args:
        rdmol: Object whose ``GetRingInfo().AtomRings()`` returns an iterable
            of rings, each ring being an iterable of atom indices.

    Returns:
        bool: ``False`` if any two distinct three-membered rings have at
        least one atom in common, ``True`` otherwise.
    """
    # Only three-membered rings can trigger a rejection, so drop the rest
    # before doing any pairwise comparison.
    atom_sets = (set(ring) for ring in rdmol.GetRingInfo().AtomRings())
    triangles = [atoms for atoms in atom_sets if len(atoms) == 3]

    # Compare each triangle with the ones before it: every unordered pair
    # is checked exactly once.
    for idx, ring_a in enumerate(triangles):
        for ring_b in triangles[:idx]:
            if not ring_a.isdisjoint(ring_b):
                return False

    return True