Spaces:
Sleeping
Sleeping
import copy | |
import traceback | |
import numpy as np | |
import multiprocessing | |
import itertools | |
import rdkit | |
import rdkit.Chem as Chem | |
rdkit.RDLogger.DisableLog('rdApp.*') | |
from SmilesPE.pretokenizer import atomwise_tokenizer | |
from .constants import RGROUP_SYMBOLS, ABBREVIATIONS, VALENCES, FORMULA_REGEX | |
def is_valid_mol(s, format_='atomtok'):
    """Return True if RDKit can parse ``s`` as a molecule.

    Args:
        s: Molecule string. It is parsed as SMILES when ``format_`` is
            ``'atomtok'``, or as InChI when ``format_`` is ``'inchi'``. In the
            InChI case, the ``'InChI=1S/'`` prefix is added if missing.
        format_: ``'atomtok'`` or ``'inchi'``.

    Raises:
        NotImplementedError: if ``format_`` is not a supported value.
    """
    if format_ == 'atomtok':
        mol = Chem.MolFromSmiles(s)
    elif format_ == 'inchi':
        if not s.startswith('InChI=1S'):
            s = f"InChI=1S/{s}"
        mol = Chem.MolFromInchi(s)
    else:
        # `raise NotImplemented` (the singleton) would raise a TypeError instead.
        raise NotImplementedError(f"Unsupported format: {format_!r}")
    return mol is not None
def _convert_smiles_to_inchi(smiles):
    """Convert one SMILES string to InChI, returning None on failure.

    This is the per-item worker of `convert_smiles_to_inchi`. Failures are
    reported as None rather than raised, so one bad input does not abort the
    batch.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        inchi = Chem.MolToInchi(mol)
    except Exception:
        return None
    # NOTE(review): RDKit appears to report some InChI failures as an empty
    # string rather than an exception. Normalize that to None so it is counted
    # as a failure by the caller.
    return inchi or None
def convert_smiles_to_inchi(smiles_list, num_workers=16):
    """Convert a batch of SMILES strings to InChI.

    Args:
        smiles_list: Iterable of SMILES strings.
        num_workers: Number of worker processes. ``<= 1`` converts in-process.

    Returns:
        ``(inchi_list, r_success)``. Failed conversions in ``inchi_list`` are
        replaced by the water InChI ``'InChI=1S/H2O/h1H2'``. ``r_success`` is
        the fraction of inputs that converted, or 0.0 for empty input.
    """
    smiles_list = list(smiles_list)
    if not smiles_list:
        # Avoid spawning a pool and dividing by zero below.
        return [], 0.0
    if num_workers <= 1:
        inchi_list = [_convert_smiles_to_inchi(s) for s in smiles_list]
    else:
        with multiprocessing.Pool(num_workers) as p:
            inchi_list = p.map(_convert_smiles_to_inchi, smiles_list, chunksize=128)
    n_success = sum(x is not None for x in inchi_list)
    r_success = n_success / len(inchi_list)
    # Placeholder so every entry is a valid InChI; `merge_inchi` later looks
    # for this exact string to fill the gaps.
    inchi_list = [x if x else 'InChI=1S/H2O/h1H2' for x in inchi_list]
    return inchi_list, r_success
def merge_inchi(inchi1, inchi2):
    """Fill the placeholder entries of ``inchi1`` from ``inchi2``.

    Every entry of ``inchi1`` equal to the water-InChI placeholder
    ``'InChI=1S/H2O/h1H2'`` is replaced by the entry at the same index of
    ``inchi2``. The edit is made on a deep copy, so ``inchi1`` is untouched.

    Returns:
        ``(merged, replaced)``: the merged sequence and how many entries were
        taken from ``inchi2``.
    """
    placeholder = 'InChI=1S/H2O/h1H2'
    merged = copy.deepcopy(inchi1)
    num_replaced = 0
    for idx, value in enumerate(inchi1):
        if value != placeholder:
            continue
        merged[idx] = inchi2[idx]
        num_replaced += 1
    return merged, num_replaced
def _get_num_atoms(smiles):
    """Return the number of atoms RDKit parses from ``smiles``, or 0 if it cannot."""
    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception:
        return 0
    # MolFromSmiles signals a parse failure by returning None. Check it
    # explicitly instead of relying on the AttributeError from None.
    return 0 if mol is None else mol.GetNumAtoms()
def get_num_atoms(smiles, num_workers=16):
    """Count atoms for one SMILES string or for a batch of them.

    Args:
        smiles: A single SMILES string, or an iterable of SMILES strings.
        num_workers: Number of worker processes for the batch case. ``<= 1``
            runs in-process.

    Returns:
        An int for a single string, otherwise a list of ints. Entries RDKit
        cannot parse count as 0.
    """
    # Use isinstance rather than `type(...) is str`, so str subclasses such as
    # numpy.str_ are treated as one SMILES instead of iterated char by char.
    if isinstance(smiles, str):
        return _get_num_atoms(smiles)
    if num_workers <= 1:
        return [_get_num_atoms(s) for s in smiles]
    with multiprocessing.Pool(num_workers) as p:
        num_atoms = p.map(_get_num_atoms, smiles)
    return num_atoms
def normalize_nodes(nodes, flip_y=True, eps=1e-6):
    """Min-max normalize 2-D node coordinates into [0, 1] on each axis.

    Args:
        nodes: Array-like of shape ``(N, 2)`` holding ``(x, y)`` rows. It must
            contain at least one row.
        flip_y: If True, the largest y maps to 0 and the smallest to 1 (the y
            axis is inverted). Otherwise the smallest y maps to 0.
        eps: Lower bound on each axis' range. A degenerate axis (all values
            equal) then normalizes to 0 instead of dividing by zero.

    Returns:
        ``np.ndarray`` of shape ``(N, 2)``.
    """
    nodes = np.asarray(nodes)
    x, y = nodes[:, 0], nodes[:, 1]
    # Vectorized reductions: builtin min/max would iterate the array in Python.
    minx, maxx = x.min(), x.max()
    miny, maxy = y.min(), y.max()
    x = (x - minx) / max(maxx - minx, eps)
    if flip_y:
        y = (maxy - y) / max(maxy - miny, eps)
    else:
        y = (y - miny) / max(maxy - miny, eps)
    return np.stack([x, y], axis=1)
def _verify_chirality(mol, coords, symbols, edges, debug=False):
    """Assign stereochemistry (double-bond E/Z and chiral tags) to ``mol`` from its 2-D layout.

    Steps (performed on ``mol`` in place, then converted with ``GetMol()``):
      1. Find chiral centers, including unassigned ones, on a sanitized
         temporary copy.
      2. Clear the direction of every single bond. Attach a conformer flagged
         as 3-D, built from ``coords`` with y mapped to ``1 - y``. Run
         ``AssignStereochemistryFrom3D``, which is what assigns double-bond
         E/Z here.
      3. Replace that conformer with a 2-D one and infer chirality from bond
         directions.
      4. For each chiral center, re-add every wedge/dash bond recorded in
         ``edges`` so that it begins at the center, then re-infer chirality.
      5. Clear the chiral tag of every non-carbon atom.

    Args:
        mol: RDKit molecule. ``RemoveBond``/``AddBond`` are called on it, so
            presumably an ``RWMol`` is expected — TODO confirm for all callers.
        coords: ``(x, y)`` pairs, one per atom, in atom-index order.
        symbols: Not used by this function.
        edges: Matrix indexed ``edges[i][j]``. Value 5 is handled as a wedge
            and 6 as a dash starting at atom ``i``.
        debug: If True, re-raise any exception instead of swallowing it.

    Returns:
        ``mol.GetMol()`` on success. If any step raises and ``debug`` is
        False, the (possibly partially modified) input object is returned.
    """
    try:
        n = mol.GetNumAtoms()
        # Make a temp mol to find chiral centers
        mol_tmp = mol.GetMol()
        Chem.SanitizeMol(mol_tmp)
        chiral_centers = Chem.FindMolChiralCenters(
            mol_tmp, includeUnassigned=True, includeCIP=False, useLegacyImplementation=False)
        chiral_center_ids = [idx for idx, _ in chiral_centers]  # List[Tuple[int, any]] -> List[int]
        # correction to clear pre-condition violation (for some corner cases)
        # NOTE: this clears the wedge/dash direction of every single bond;
        # wedges/dashes at chiral centers are re-added from `edges` in the
        # second loop further below.
        for bond in mol.GetBonds():
            if bond.GetBondType() == Chem.BondType.SINGLE:
                bond.SetBondDir(Chem.BondDir.NONE)
        # Create conformer from 2D coordinate
        conf = Chem.Conformer(n)
        conf.Set3D(True)
        for i, (x, y) in enumerate(coords):
            conf.SetAtomPosition(i, (x, 1 - y, 0))
        mol.AddConformer(conf)
        Chem.SanitizeMol(mol)
        Chem.AssignStereochemistryFrom3D(mol)
        # NOTE: seems that only AssignStereochemistryFrom3D can handle double bond E/Z
        # So we do this first, remove the conformer and add back the 2D conformer for chiral correction
        mol.RemoveAllConformers()
        conf = Chem.Conformer(n)
        conf.Set3D(False)
        for i, (x, y) in enumerate(coords):
            conf.SetAtomPosition(i, (x, 1 - y, 0))
        mol.AddConformer(conf)
        # Magic, inferring chirality from coordinates and BondDir. DO NOT CHANGE.
        Chem.SanitizeMol(mol)
        Chem.AssignChiralTypesFromBondDirs(mol)
        Chem.AssignStereochemistry(mol, force=True)
        # Second loop to reset any wedge/dash bond to be starting from the chiral center)
        # The bond is removed and re-added as (i, j) so that its begin atom is
        # the chiral center i. The commented asserts suggest `edges` stores the
        # opposite code (6 <-> 5) at [j][i]; not verified here.
        for i in chiral_center_ids:
            for j in range(n):
                if edges[i][j] == 5:
                    # assert edges[j][i] == 6
                    mol.RemoveBond(i, j)
                    mol.AddBond(i, j, Chem.BondType.SINGLE)
                    mol.GetBondBetweenAtoms(i, j).SetBondDir(Chem.BondDir.BEGINWEDGE)
                elif edges[i][j] == 6:
                    # assert edges[j][i] == 5
                    mol.RemoveBond(i, j)
                    mol.AddBond(i, j, Chem.BondType.SINGLE)
                    mol.GetBondBetweenAtoms(i, j).SetBondDir(Chem.BondDir.BEGINDASH)
        Chem.AssignChiralTypesFromBondDirs(mol)
        Chem.AssignStereochemistry(mol, force=True)
        # reset chiral tags for non-carbon atom
        for atom in mol.GetAtoms():
            if atom.GetSymbol() != "C":
                atom.SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
        mol = mol.GetMol()
    except Exception as e:
        # Best effort: without debug, fall through and return `mol` as-is.
        if debug:
            raise e
        pass
    return mol
def _parse_tokens(tokens: list):
    """
    Parse tokens of condensed formula into list of pairs `(elt, num)`
    where `num` is the multiplicity of the atom (or nested condensed formula) `elt`.
    A parenthesized group becomes a nested list of pairs. Nested parentheses
    are matched by depth, e.g. tokens of "(C(H)3)2" -> [([('C', 1), ([('H', 1)], 3)], 2)].
    An unmatched '(' consumes the rest of the tokens.
    Used by `_parse_formula`, which does the same thing but takes a formula in string form as input
    """
    elements = []
    n_tokens = len(tokens)
    i = 0
    while i < n_tokens:
        if tokens[i] == '(':
            # Find the ')' that matches this '(' (not merely the first ')'),
            # so nested groups such as CH(CH(CH3)2)2 parse correctly.
            depth = 0
            j = i
            while j < n_tokens:
                if tokens[j] == '(':
                    depth += 1
                elif tokens[j] == ')':
                    depth -= 1
                    if depth == 0:
                        break
                j += 1
            elt = _parse_tokens(tokens[i + 1:j])
        else:
            elt = tokens[i]
            j = i
        j += 1  # step past the element (or its closing ')')
        if j < n_tokens and tokens[j].isnumeric():
            num = int(tokens[j])
            j += 1
        else:
            num = 1
        elements.append((elt, num))
        i = j
    return elements
def _parse_formula(formula: str):
    """Tokenize a condensed formula string and parse it into `(elt, num)` pairs.

    Tokens come from ``FORMULA_REGEX.findall``. They are handed to
    `_parse_tokens`, which turns parenthesized groups into nested lists.
    Example: "C2H4O" -> [('C', 2), ('H', 4), ('O', 1)]
    """
    return _parse_tokens(FORMULA_REGEX.findall(formula))
def _expand_carbon(elements: list):
    """
    Given list of pairs `(elt, num)`, output single list of all atoms in order.

    A carbon with multiplicity a > 1 followed by another entry X with
    multiplicity b (CaXb) is expanded into a carbons, each followed by b // a
    copies of X. The b % a leftover copies of X are appended after the last
    carbon. X is not restricted to any element.
    Nested formulas (lists of pairs) are expanded recursively into nested lists,
    including when they are the X of a carbon sequence.
    Example: [('C', 2), ('H', 4), ('O', 1)] -> ['C', 'H', 'H', 'C', 'H', 'H', 'O']
    """
    expanded = []
    i = 0
    while i < len(elements):
        elt, num = elements[i]
        # expand carbon sequence
        if elt == 'C' and num > 1 and i + 1 < len(elements):
            next_elt, next_num = elements[i + 1]
            # A nested group after the carbons must be flattened too; otherwise
            # a raw list of (elt, num) pairs leaks into the output.
            if isinstance(next_elt, list):
                next_elt = _expand_carbon(next_elt)
            quotient, remainder = divmod(next_num, num)
            for _ in range(num):
                expanded.append('C')
                expanded.extend([next_elt] * quotient)
            expanded.extend([next_elt] * remainder)
            i += 2
        # recurse if `elt` itself is a list (nested formula)
        elif isinstance(elt, list):
            expanded.extend([_expand_carbon(elt)] * num)
            i += 1
        # simplest case: append `elt` `num` times
        else:
            expanded.extend([elt] * num)
            i += 1
    return expanded
def _expand_abbreviation(abbrev):
    """Return the SMILES fragment for a single condensed-formula symbol.

    * A known abbreviation returns the ``.smiles`` of its ``ABBREVIATIONS`` entry.
    * An R-group symbol (a member of ``RGROUP_SYMBOLS``, or ``R`` followed by
      digits) returns ``[<digits>*]`` when everything after the first
      character is digits, and ``*`` otherwise.
    * Anything else is returned wrapped in brackets, ``[abbrev]``.

    Used by `_condensed_formula_list_to_smiles` when it meets a symbol in a
    condensed formula.
    """
    if abbrev in ABBREVIATIONS:
        return ABBREVIATIONS[abbrev].smiles
    tail = abbrev[1:]
    if abbrev not in RGROUP_SYMBOLS and not (abbrev[0] == 'R' and tail.isdigit()):
        return f'[{abbrev}]'
    return f'[{tail}*]' if tail.isdigit() else '*'
def _get_bond_symb(bond_num):
    """Map a bond order to the SMILES bond symbol written before an atom.

    0 -> '.', 1 -> '' (single bonds are implicit), 2 -> '=', 3 -> '#'.
    Any other value -> ''.
    Used by `_condensed_formula_list_to_smiles` while writing the SMILES string.
    """
    symbol_for_order = {0: '.', 1: '', 2: '=', 3: '#'}
    return symbol_for_order.get(bond_num, '')
def _condensed_formula_list_to_smiles(formula_list, start_bond, end_bond=None, direction=None):
    """
    Converts condensed formula (in the form of a list of symbols) to smiles,
    using a depth-first search over the possible valences (from VALENCES) of each symbol.
    Input:
    `formula_list`: e.g. ['C', 'H', 'H', 'N', ['C', 'H', 'H', 'H'], ['C', 'H', 'H', 'H']] for CH2N(CH3)2
    `start_bond`: # bonds attached to beginning of formula
    `end_bond`: # bonds that must remain at the end of the formula (no check if None)
    `direction` (1, -1, or None): direction in which to process the list (1: left to right; -1: right to left; None: try 1, then -1)
    Returns:
    `smiles`: smiles corresponding to input condensed formula (None if `direction` is None and both directions fail)
    `bonds_left`: bonds remaining at the end of the formula (for connecting back to main molecule); equals `end_bond` on success if specified
    `num_trials`: number of search steps taken; at each level, no further valences are tried once it exceeds 10000
    `success` (bool): whether conversion was successful
    """
    # `direction` not specified: try left to right; if fails, try right to left
    if direction is None:
        num_trials = 1
        for dir_choice in [1, -1]:
            smiles, bonds_left, trials, success = _condensed_formula_list_to_smiles(formula_list, start_bond, end_bond, dir_choice)
            num_trials += trials
            if success:
                return smiles, bonds_left, num_trials, success
        return None, None, num_trials, False
    assert direction == 1 or direction == -1

    def dfs(smiles, bonds_left, cur_idx, add_idx):
        """
        Attach formula_list[add_idx] (and, recursively, the remaining symbols) to the SMILES built so far.
        `smiles`: SMILES string so far
        `bonds_left`: bonds remaining on current atom for subsequent atoms to be attached to
        `cur_idx`: index (in list `formula`) of current atom (i.e. atom to which subsequent atoms are being attached);
            -1 (direction 1) or len(formula_list) (direction -1) before any atom is placed
        `add_idx`: index (in list `formula`) of atom to be attached to current atom
        Returns (smiles, bonds_left, num_trials, success).
        Note: "atom" could refer to nested condensed formula (e.g. CH3 in CH2N(CH3)2)
        """
        num_trials = 1
        # end of formula: return result
        if (direction == 1 and add_idx == len(formula_list)) or (direction == -1 and add_idx == -1):
            if end_bond is not None and end_bond != bonds_left:
                return smiles, bonds_left, num_trials, False
            return smiles, bonds_left, num_trials, True
        # no more bonds but there are atoms remaining: conversion failed
        if bonds_left <= 0:
            return smiles, bonds_left, num_trials, False
        to_add = formula_list[add_idx]  # atom to be added to current atom
        if isinstance(to_add, list):  # "atom" added is a list (i.e. nested condensed formula): assume valence of 1
            if bonds_left > 1:
                # "atom" added does not use up remaining bonds of current atom
                # get smiles of "atom" (which is itself a condensed formula)
                add_str, val, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction)
                # Bonds left over at the end of the nested group are folded into
                # the bond order written in front of it (val + 1).
                if val > 0:
                    add_str = _get_bond_symb(val + 1) + add_str
                num_trials += trials
                if not success:
                    return smiles, bonds_left, num_trials, False
                # put smiles of "atom" in parentheses and append to smiles; go to next atom to add to current atom
                result = dfs(smiles + f'({add_str})', bonds_left - 1, cur_idx, add_idx + direction)
            else:
                # "atom" added uses up remaining bonds of current atom
                # get smiles of "atom" and bonds left on it
                add_str, bonds_left, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction)
                num_trials += trials
                if not success:
                    return smiles, bonds_left, num_trials, False
                # append smiles of "atom" (without parentheses) to smiles; it becomes new current atom
                result = dfs(smiles + add_str, bonds_left, add_idx, add_idx + direction)
            smiles, bonds_left, trials, success = result
            num_trials += trials
            return smiles, bonds_left, num_trials, success
        # atom added is a single symbol (as opposed to nested condensed formula)
        # NOTE(review): for direction == -1 the initial cur_idx is
        # len(formula_list) (>= 0), so the `cur_idx >= 0` checks below also
        # prefix a bond symbol to the very first atom; possibly unintended.
        for val in VALENCES.get(to_add, [1]):  # try all possible valences of atom added
            add_str = _expand_abbreviation(to_add)  # expand to smiles if symbol is abbreviation
            if bonds_left > val:  # atom added does not use up remaining bonds of current atom; go to next atom to add to current atom
                if cur_idx >= 0:
                    add_str = _get_bond_symb(val) + add_str
                result = dfs(smiles + f'({add_str})', bonds_left - val, cur_idx, add_idx + direction)
            else:  # atom added uses up remaining bonds of current atom; it becomes new current atom
                if cur_idx >= 0:
                    add_str = _get_bond_symb(bonds_left) + add_str
                result = dfs(smiles + add_str, val - bonds_left, add_idx, add_idx + direction)
            trials, success = result[2:]
            num_trials += trials
            if success:
                return result[0], result[1], num_trials, success
            # Cap on search effort: stop trying further valences at this level.
            if num_trials > 10000:
                break
        return smiles, bonds_left, num_trials, False

    cur_idx = -1 if direction == 1 else len(formula_list)
    add_idx = 0 if direction == 1 else len(formula_list) - 1
    return dfs('', start_bond, cur_idx, add_idx)
def get_smiles_from_symbol(symbol, mol, atom, bonds):
    """Convert an abbreviation or condensed formula into a SMILES fragment.

    Args:
        symbol: An abbreviation (looked up in ``ABBREVIATIONS``) or a condensed
            formula such as ``'CH2OH'``.
        mol, atom: Not used by this implementation.
        bonds: RDKit bonds attaching ``symbol`` to the rest of the molecule.
            The sum of their orders, truncated to int, is the number of bonds
            the formula must accept at its start.

    Returns:
        A SMILES string, or None if the symbol is longer than 20 characters or
        no parse succeeds in either direction (tried automatically).
    """
    if symbol in ABBREVIATIONS:
        return ABBREVIATIONS[symbol].smiles
    if len(symbol) > 20:
        return None
    attach_order = int(sum(b.GetBondTypeAsDouble() for b in bonds))
    atom_list = _expand_carbon(_parse_formula(symbol))
    outcome = _condensed_formula_list_to_smiles(atom_list, attach_order, None)
    fragment, succeeded = outcome[0], outcome[3]
    return fragment if succeeded else None
def _replace_functional_group(smiles):
    """Rewrite a predicted SMILES string so RDKit can parse it.

    1. ``<unk>`` tokens become ``C``.
    2. Bracketed symbols from ``RGROUP_SYMBOLS`` become dummy atoms:
       ``[R<digits>]`` becomes ``[<int(digits)>*]``, any other R-group becomes ``*``.
    3. Every remaining bracket token that is a known abbreviation, or that
       RDKit cannot parse as an atom, becomes ``[<isotope>*]``. The isotope
       label is fresh: it starts at 50 and skips labels already present in
       the string or already assigned.

    Returns:
        ``(smiles, mappings)``. ``mappings`` maps each placeholder isotope to
        the original bracket content (without brackets).
    """
    smiles = smiles.replace('<unk>', 'C')
    for rgroup in RGROUP_SYMBOLS:
        bracketed = f'[{rgroup}]'
        if bracketed not in smiles:
            continue
        if rgroup[0] == 'R' and rgroup[1:].isdigit():
            replacement = f'[{int(rgroup[1:])}*]'
        else:
            replacement = '*'
        smiles = smiles.replace(bracketed, replacement)

    # Tokens RDKit cannot parse get a numbered dummy atom; the number records
    # which original symbol it stands for.
    mappings = {}  # isotope : symbol
    next_isotope = 50
    rebuilt = []
    for token in atomwise_tokenizer(smiles):
        needs_placeholder = token[0] == '[' and (
            token[1:-1] in ABBREVIATIONS or Chem.AtomFromSmiles(token) is None)
        if not needs_placeholder:
            rebuilt.append(token)
            continue
        while f'[{next_isotope}*]' in smiles or f'[{next_isotope}*]' in rebuilt:
            next_isotope += 1
        mappings[next_isotope] = token[1:-1]
        rebuilt.append(f'[{next_isotope}*]')
    return ''.join(rebuilt), mappings
def convert_smiles_to_mol(smiles):
    """Parse ``smiles`` with RDKit.

    Returns None for None or empty input, or when RDKit raises. Otherwise it
    returns whatever ``Chem.MolFromSmiles`` returns, which is None when the
    string cannot be parsed.
    """
    if smiles is None or smiles == '':
        return None
    try:
        return Chem.MolFromSmiles(smiles)
    except Exception:
        # Best effort: callers treat None as "could not build a molecule".
        return None
# Integer bond order -> RDKit bond type. Used by `_expand_functional_group`
# when re-attaching an expanded substituent to the main molecule.
BOND_TYPES = {
    1: Chem.rdchem.BondType.SINGLE,
    2: Chem.rdchem.BondType.DOUBLE,
    3: Chem.rdchem.BondType.TRIPLE,
}
def _expand_functional_group(mol, mappings, debug=False):
    """Replace abbreviation / condensed-formula placeholder atoms with their structures.

    Placeholders are ``*`` atoms. Their label is the atom alias, or
    ``mappings[isotope]`` when the atom's isotope is a key of ``mappings``
    (see `_replace_functional_group`). A placeholder is expanded when its label
    is not in ``RGROUP_SYMBOLS`` and `get_smiles_from_symbol` yields a SMILES
    RDKit can parse. In that case the placeholder is removed, and the
    substituent's first atom is bonded to each former neighbour with the
    original bond order. Placeholders that cannot be converted stay, with
    their isotope reset to 0.

    Args:
        mol: RDKit molecule.
        mappings: ``{isotope: symbol}`` dict.
        debug: Not used.

    Returns:
        ``(smiles, mol)``. When nothing needs expanding, the input ``mol`` is
        returned together with its SMILES.
    """
    def _need_expand(mol, mappings):
        return any(len(Chem.GetAtomAlias(atom)) > 0 for atom in mol.GetAtoms()) or len(mappings) > 0

    if not _need_expand(mol, mappings):
        return Chem.MolToSmiles(mol), mol

    mol_w = Chem.RWMol(mol)
    num_atoms = mol_w.GetNumAtoms()
    for atom in mol_w.GetAtoms():  # reset radical electrons
        atom.SetNumRadicalElectrons(0)
    atoms_to_remove = []
    # Only the original atoms are visited; substituent atoms appended below
    # get indices >= num_atoms.
    for i in range(num_atoms):
        atom = mol_w.GetAtomWithIdx(i)
        if atom.GetSymbol() != '*':
            continue
        symbol = Chem.GetAtomAlias(atom)
        isotope = atom.GetIsotope()
        if isotope > 0 and isotope in mappings:
            symbol = mappings[isotope]
        if not (isinstance(symbol, str) and len(symbol) > 0):
            continue
        # rgroups do not need to be expanded
        if symbol in RGROUP_SYMBOLS:
            continue
        bonds = atom.GetBonds()
        sub_smiles = get_smiles_from_symbol(symbol, mol_w, atom, bonds)
        # create mol object for abbreviation/condensed formula from its SMILES
        mol_r = convert_smiles_to_mol(sub_smiles)
        if mol_r is None:
            atom.SetIsotope(0)
            continue
        # Read everything needed from the bonds BEFORE removing them:
        # RemoveBond deletes the underlying RDKit bond objects, so reading
        # them afterwards is unsafe.
        adjacent_indices = [bond.GetOtherAtomIdx(i) for bond in bonds]
        bond_orders = [int(bond.GetBondTypeAsDouble()) for bond in bonds]
        # remove bonds connected to abbreviation/condensed formula
        for adjacent_idx in adjacent_indices:
            mol_w.RemoveBond(i, adjacent_idx)
        # After CombineMols, atoms of mol_r are appended after the atoms of
        # mol_w, so mol_r's atom k gets index offset + k.
        offset = mol_w.GetNumAtoms()
        attach_idx = offset  # the substituent bonds through its first atom
        radical_atoms_r = [offset + atm.GetIdx() for atm in mol_r.GetAtoms()
                           if atm.GetNumRadicalElectrons() and atm.GetIdx() > 0]
        # combine main body and substituent into a single molecule object
        mol_w = Chem.RWMol(Chem.CombineMols(mol_w, mol_r))
        # connect substituent to main body with the original bond orders
        for adjacent_idx, bond_order in zip(adjacent_indices, bond_orders):
            mol_w.AddBond(adjacent_idx, attach_idx, order=BOND_TYPES[bond_order])
        # clear radicals on the substituent's attachment and radical atoms
        for atm in [attach_idx] + radical_atoms_r:
            mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0)
        atoms_to_remove.append(i)
    # Remove placeholder atoms last, largest index first, so indices still to
    # be removed stay valid.
    for i in sorted(atoms_to_remove, reverse=True):
        mol_w.RemoveAtom(i)
    smiles = Chem.MolToSmiles(mol_w)
    return smiles, mol_w.GetMol()
def _convert_graph_to_smiles(coords, symbols, edges, image=None, debug=False):
    """Build an RDKit molecule from a predicted atom/bond graph and return its SMILES.

    Args:
        coords: Per-atom ``(x, y)`` coordinates, in atom order.
        symbols: Per-atom symbols. For R-group/abbreviation lookup, a leading
            ``[`` and the last character are stripped. Symbols RDKit cannot
            parse as an atom become ``*`` atoms aliased with the symbol; these
            are expanded later by `_expand_functional_group`.
        edges: ``n x n`` matrix of bond codes. Only the upper triangle
            (``j > i``) creates bonds: 1 single, 2 double, 3 triple,
            4 aromatic, 5 single wedge from atom i, 6 single dash from atom i.
        image: Optional array with ``shape == (height, width, _)``. When given,
            x coordinates are scaled by the aspect ratio and both axes by 10.
        debug: Print tracebacks and also return the RDKit molecule.

    Returns:
        ``(pred_smiles, pred_molblock, success)``, or
        ``(pred_smiles, pred_molblock, mol, success)`` when ``debug`` is True.
        On failure ``pred_smiles`` is ``'<invalid>'`` and ``pred_molblock`` is ``''``.
    """
    mol = Chem.RWMol()
    n = len(symbols)
    ids = []
    for i in range(n):
        raw_symbol = symbols[i]
        # startswith (instead of raw_symbol[0]) tolerates an empty symbol
        # rather than crashing the whole worker batch with IndexError.
        symbol = raw_symbol[1:-1] if raw_symbol.startswith('[') else raw_symbol
        if symbol in RGROUP_SYMBOLS:
            atom = Chem.Atom("*")
            if symbol.startswith('R') and symbol[1:].isdigit():
                atom.SetIsotope(int(symbol[1:]))
            Chem.SetAtomAlias(atom, symbol)
        elif symbol in ABBREVIATIONS:
            atom = Chem.Atom("*")
            Chem.SetAtomAlias(atom, symbol)
        else:
            try:  # try to get SMILES of atom
                atom = Chem.AtomFromSmiles(raw_symbol)
            except Exception:
                atom = None
            if atom is None:  # otherwise, abbreviation or condensed formula
                atom = Chem.Atom("*")
                Chem.SetAtomAlias(atom, symbol)
            else:
                atom.SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
        if atom.GetSymbol() == '*':
            atom.SetProp('molFileAlias', symbol)
        idx = mol.AddAtom(atom)
        assert idx == i
        ids.append(idx)
    for i in range(n):
        for j in range(i + 1, n):
            if edges[i][j] == 1:
                mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
            elif edges[i][j] == 2:
                mol.AddBond(ids[i], ids[j], Chem.BondType.DOUBLE)
            elif edges[i][j] == 3:
                mol.AddBond(ids[i], ids[j], Chem.BondType.TRIPLE)
            elif edges[i][j] == 4:
                mol.AddBond(ids[i], ids[j], Chem.BondType.AROMATIC)
            elif edges[i][j] == 5:
                mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
                mol.GetBondBetweenAtoms(ids[i], ids[j]).SetBondDir(Chem.BondDir.BEGINWEDGE)
            elif edges[i][j] == 6:
                mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
                mol.GetBondBetweenAtoms(ids[i], ids[j]).SetBondDir(Chem.BondDir.BEGINDASH)
    pred_smiles = '<invalid>'
    try:
        # TODO: move to an util function
        if image is not None:
            height, width, _ = image.shape
            ratio = width / height
            coords = [[x * ratio * 10, y * 10] for x, y in coords]
        mol = _verify_chirality(mol, coords, symbols, edges, debug)
        # molblock is obtained before expanding func groups, otherwise the expanded group won't have coordinates.
        # TODO: make sure molblock has the abbreviation information
        pred_molblock = Chem.MolToMolBlock(mol)
        pred_smiles, mol = _expand_functional_group(mol, {}, debug)
        success = True
    except Exception:
        if debug:
            print(traceback.format_exc())
        pred_molblock = ''
        success = False
    if debug:
        return pred_smiles, pred_molblock, mol, success
    return pred_smiles, pred_molblock, success
def convert_graph_to_smiles(coords, symbols, edges, images=None, num_workers=16): | |
if images is None: | |
args_zip = zip(coords, symbols, edges) | |
else: | |
args_zip = zip(coords, symbols, edges, images) | |
if num_workers <= 1: | |
results = itertools.starmap(_convert_graph_to_smiles, args_zip) | |
results = list(results) | |
else: | |
with multiprocessing.Pool(num_workers) as p: | |
results = p.starmap(_convert_graph_to_smiles, args_zip, chunksize=128) | |
smiles_list, molblock_list, success = zip(*results) | |
r_success = np.mean(success) | |
return smiles_list, molblock_list, r_success | |
def _postprocess_smiles(smiles, coords=None, symbols=None, edges=None, molblock=False, debug=False):
    """Clean up one predicted SMILES string and expand its abbreviations.

    When ``coords``, ``symbols`` and ``edges`` are all given, stereo marks
    (``@``, ``/``, ``\\``) are removed from the SMILES and stereochemistry is
    re-derived with `_verify_chirality`.

    Args:
        smiles: Predicted SMILES. Non-strings and ``''`` are rejected.
        coords, symbols, edges: Optional graph prediction for stereo recovery.
        molblock: Also return a MolBlock, taken before abbreviation expansion.
        debug: Print tracebacks and also return the RDKit molecule.

    Returns:
        ``(pred_smiles, pred_molblock, success)``, or
        ``(pred_smiles, pred_molblock, mol, success)`` when ``debug`` is True.
        On failure ``pred_smiles`` is the input unchanged.
    """
    if not isinstance(smiles, str) or smiles == '':
        # Keep the same arity as the normal return so batch callers can
        # `zip(*results)` into three (or four) sequences.
        if debug:
            return '', '', None, False
        return '', '', False
    mol = None
    pred_molblock = ''
    try:
        pred_smiles, mappings = _replace_functional_group(smiles)
        if coords is not None and symbols is not None and edges is not None:
            # Drop SMILES stereo marks; stereo is re-derived from the graph.
            pred_smiles = pred_smiles.replace('@', '').replace('/', '').replace('\\', '')
            mol = Chem.RWMol(Chem.MolFromSmiles(pred_smiles, sanitize=False))
            mol = _verify_chirality(mol, coords, symbols, edges, debug)
        else:
            mol = Chem.MolFromSmiles(pred_smiles, sanitize=False)
        if molblock:
            pred_molblock = Chem.MolToMolBlock(mol)
        pred_smiles, mol = _expand_functional_group(mol, mappings)
        success = True
    except Exception:
        if debug:
            print(traceback.format_exc())
        pred_smiles = smiles
        pred_molblock = ''
        success = False
    if debug:
        return pred_smiles, pred_molblock, mol, success
    return pred_smiles, pred_molblock, success
def postprocess_smiles(smiles, coords=None, symbols=None, edges=None, molblock=False, num_workers=16): | |
with multiprocessing.Pool(num_workers) as p: | |
if coords is not None and symbols is not None and edges is not None: | |
results = p.starmap(_postprocess_smiles, zip(smiles, coords, symbols, edges), chunksize=128) | |
else: | |
results = p.map(_postprocess_smiles, smiles, chunksize=128) | |
smiles_list, molblock_list, success = zip(*results) | |
r_success = np.mean(success) | |
return smiles_list, molblock_list, r_success | |
def _keep_main_molecule(smiles, debug=False):
    """Return the SMILES of the largest fragment (by atom count) of ``smiles``.

    Single-fragment input is returned unchanged, without re-canonicalization.
    On any RDKit failure the input is returned as-is. With ``debug``, the
    traceback is printed first.
    """
    try:
        fragments = Chem.GetMolFrags(Chem.MolFromSmiles(smiles), asMols=True)
        if len(fragments) <= 1:
            return smiles
        sizes = [frag.GetNumAtoms() for frag in fragments]
        return Chem.MolToSmiles(fragments[np.argmax(sizes)])
    except Exception:
        if debug:
            print(traceback.format_exc())
        return smiles
def keep_main_molecule(smiles, num_workers=16):
    """Apply `_keep_main_molecule` to every SMILES string in ``smiles``.

    Args:
        smiles: Iterable of SMILES strings.
        num_workers: Number of worker processes. ``<= 1`` runs in-process,
            matching `convert_graph_to_smiles`.

    Returns:
        A list of SMILES strings in input order.
    """
    if num_workers <= 1:
        return [_keep_main_molecule(s) for s in smiles]
    with multiprocessing.Pool(num_workers) as p:
        results = p.map(_keep_main_molecule, smiles, chunksize=128)
    return results