# ICLR_FLAG/utils/chem.py
# zaixizhang
# renew
# 10efe81
# raw
# history blame
# 3.5 kB
import copy
import torch
from io import BytesIO
from openbabel import openbabel
from torch_geometric.utils import to_networkx
from torch_geometric.data import Data
from torch_scatter import scatter
from rdkit import Chem
from rdkit.Chem.rdchem import Mol, HybridizationType, BondType
from rdkit.Chem.rdchem import BondType as BT
# Bond-type enum member -> integer code, where the code is the member's
# position when iterating ``BT.names``.
BOND_TYPES = {
    bond_type: code for code, bond_type in enumerate(BT.names.values())
}
# Integer code -> key of ``BT.names`` (same enumeration order as BOND_TYPES).
BOND_NAMES = dict(enumerate(BT.names.keys()))
def rdmol_to_data(mol, smiles=None):
    """Convert a single-conformer RDKit molecule into a graph ``Data`` object.

    Args:
        mol: RDKit molecule; it must carry exactly one conformer.
        smiles: Optional SMILES string to store on the result. When ``None``
            it is generated from ``mol`` after removing hydrogens.

    Returns:
        A ``Data`` object with the fields
        ``atom_type`` (atomic numbers, ``torch.long``),
        ``pos`` (coordinates of conformer 0, ``torch.float32``),
        ``edge_index`` (each bond stored in both directions, ordered by
        source index then target index),
        ``edge_type`` (``BOND_TYPES`` code of each directed edge,
        ``torch.long``),
        ``rdmol`` (deep copy of ``mol``) and ``smiles``.
        An undirected networkx view of the graph is attached as ``data.nx``.

    Raises:
        AssertionError: if ``mol`` does not have exactly one conformer.
    """
    # Kept as an assert so callers that rely on AssertionError keep working.
    assert mol.GetNumConformers() == 1, \
        "rdmol_to_data expects a molecule with exactly one conformer"

    num_atoms = mol.GetNumAtoms()
    pos = torch.tensor(mol.GetConformer(0).GetPositions(), dtype=torch.float32)
    z = torch.tensor(
        [atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=torch.long
    )

    # Store every bond twice (start->end and end->start) so the edge list
    # describes an undirected graph.
    row, col, edge_type = [], [], []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        row += [start, end]
        col += [end, start]
        edge_type += 2 * [BOND_TYPES[bond.GetBondType()]]

    edge_index = torch.tensor([row, col], dtype=torch.long)
    # Explicit dtype: without it a molecule with no bonds would produce a
    # float32 tensor (torch's default for an empty list) instead of int64.
    edge_type = torch.tensor(edge_type, dtype=torch.long)

    # Sort edges by (source, target) using a single combined integer key.
    perm = (edge_index[0] * num_atoms + edge_index[1]).argsort()
    edge_index = edge_index[:, perm]
    edge_type = edge_type[perm]

    if smiles is None:
        smiles = Chem.MolToSmiles(Chem.RemoveHs(mol))

    data = Data(atom_type=z, pos=pos, edge_index=edge_index, edge_type=edge_type,
                rdmol=copy.deepcopy(mol), smiles=smiles)
    data.nx = to_networkx(data, to_undirected=True)
    return data
def generated_to_xyz(data):
    """Render the ligand context atoms of ``data`` as XYZ-format text.

    Args:
        data: Object exposing ``ligand_context_element`` (per-atom atomic
            numbers) and ``ligand_context_pos`` (per-atom x, y, z values).

    Returns:
        A string whose first line is the atom count, followed by an empty
        line and then one ``symbol x y z`` line per atom, with coordinates
        printed to 8 decimal places.
    """
    periodic_table = Chem.GetPeriodicTable()
    elements = data.ligand_context_element
    coords = data.ligand_context_pos
    atom_count = elements.size(0)

    # Header: atom count, then the (empty) second line.
    chunks = ["%d\n\n" % (atom_count, )]
    for idx in range(atom_count):
        symbol = periodic_table.GetElementSymbol(elements[idx].item())
        cx, cy, cz = coords[idx].clone().cpu().tolist()
        chunks.append("%s %.8f %.8f %.8f\n" % (symbol, cx, cy, cz))
    return "".join(chunks)
def generated_to_sdf(data):
    """Convert the ligand context atoms of ``data`` to SDF text via OpenBabel.

    The atoms are first written as XYZ text with ``generated_to_xyz``, and
    that text is then read and re-written by an OpenBabel converter
    configured for "xyz" input and "sdf" output.

    Args:
        data: Object accepted by ``generated_to_xyz``.

    Returns:
        The string produced by OpenBabel's ``WriteString``.
    """
    xyz_text = generated_to_xyz(data)

    converter = openbabel.OBConversion()
    converter.SetInAndOutFormats("xyz", "sdf")

    ob_mol = openbabel.OBMol()
    # NOTE(review): the return value of ReadString is not checked, so a
    # failed parse is not reported here.
    converter.ReadString(ob_mol, xyz_text)
    return converter.WriteString(ob_mol)
def sdf_to_rdmol(sdf):
    """Parse SDF text and return its first record as an RDKit molecule.

    Args:
        sdf: SDF content as a ``str``.

    Returns:
        The first item yielded by the RDKit supplier, or ``None`` when the
        supplier yields nothing.
    """
    sdf_bytes = BytesIO(sdf.encode())
    supplier = Chem.ForwardSDMolSupplier(sdf_bytes)
    # Only the first record is wanted; fall back to None for empty input.
    return next(iter(supplier), None)
def generated_to_rdmol(data):
    """Build an RDKit molecule from generated data by going through SDF text.

    Args:
        data: Object accepted by ``generated_to_sdf``.

    Returns:
        Whatever ``sdf_to_rdmol`` returns for the generated SDF text
        (a molecule, or ``None``).
    """
    return sdf_to_rdmol(generated_to_sdf(data))
def filter_rd_mol(rdmol):
    """Check that no two 3-membered rings of ``rdmol`` share an atom.

    Args:
        rdmol: Object whose ``GetRingInfo().AtomRings()`` returns an iterable
            of rings, each ring being an iterable of atom indices.

    Returns:
        ``False`` if at least one pair of distinct 3-membered rings has an
        atom in common, ``True`` otherwise (including when there are fewer
        than two 3-membered rings).
    """
    ring_atom_sets = [set(ring) for ring in rdmol.GetRingInfo().AtomRings()]

    # Only 3-membered rings can take part in a rejected pair, so narrow the
    # list before the pairwise scan.
    three_rings = [atoms for atoms in ring_atom_sets if len(atoms) == 3]

    # Compare each ring with the ones before it: every unordered pair is
    # visited exactly once.
    for idx, ring_a in enumerate(three_rings):
        for ring_b in three_rings[:idx]:
            if not ring_a.isdisjoint(ring_b):
                return False
    return True