Spaces:
Running
Running
"""SMILES utilities.""" | |
import logging
import os
from io import StringIO
from operator import itemgetter
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import matplotlib
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import regex
from matplotlib.ticker import FormatStrFormatter, ScalarFormatter
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D

from configuration import SMILES_LANGUAGE, SMILES_TOKENIZE_FN
logger = logging.getLogger(__name__)

# NOTE: force the non-interactive "Agg" backend; per the original note this
# avoids segfaults in matplotlib (presumably when running headless — confirm).
matplotlib.use("Agg")

# Full SMILES vocabulary; tokens outside it are treated as housekeeping.
MOLECULE_TOKENS = set(SMILES_LANGUAGE.token_to_index.keys())
# A token is "non-atom" when it is entirely a digit, a %-prefixed number,
# a run of punctuation, or a run of math symbols.
NON_ATOM_REGEX = regex.compile(r"^(\d|\%\d+|\p{P}+|\p{Math}+)$")
NON_ATOM_TOKENS = {
    token for token in MOLECULE_TOKENS if NON_ATOM_REGEX.match(token)
}

# Color map and the available value-to-color normalizations.
CMAP = cm.Oranges
COLOR_NORMALIZERS = {
    "linear": mpl.colors.Normalize,
    "logarithmic": mpl.colors.LogNorm,
}

# Drawing settings, overridable through environment variables.
ATOM_RADII = float(os.environ.get("PACCMANN_ATOM_RADII", 0.5))
SVG_WIDTH = int(os.environ.get("PACCMANN_SVG_WIDTH", 400))
SVG_HEIGHT = int(os.environ.get("PACCMANN_SVG_HEIGHT", 200))
COLOR_NORMALIZATION = os.environ.get(
    "PACCMANN_COLOR_NORMALIZATION", "logarithmic"
)
def validate_smiles(smiles: str) -> bool:
    """
    Check whether RDKit can parse a SMILES string into a molecule.

    Args:
        smiles (str): a SMILES string.
    Returns:
        bool: True when ``Chem.MolFromSmiles`` yields a molecule, False when
            it yields None.
    """
    return Chem.MolFromSmiles(smiles) is not None
def canonicalize_smiles(smiles: str) -> str:
    """
    Return the canonical SMILES that RDKit writes for the given SMILES.

    Args:
        smiles (str): a SMILES string.
    Returns:
        str: the output of ``Chem.MolToSmiles`` on the parsed molecule.

    NOTE(review): the parse result is not checked. For an unparseable SMILES
    ``Chem.MolFromSmiles`` returns None and ``Chem.MolToSmiles`` presumably
    raises on it — consider calling ``validate_smiles`` first.
    """
    return Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
def remove_housekeeping_from_tokens_and_smiles_attention(
    tokens: Iterable[str],
    smiles_attention: Iterable[float],
    *,
    molecule_tokens: Optional[Iterable[str]] = None,
) -> Tuple[Iterable[str], Iterable[float]]:
    """
    Remove housekeeping tokens and corresponding attention weights.

    A token is kept when it belongs to the SMILES vocabulary. The attention
    value at the same position is kept together with it.

    Args:
        tokens (Iterable[str]): tokens obtained from the SMILES.
        smiles_attention (Iterable[float]): SMILES attention, aligned
            position-by-position with ``tokens``.
        molecule_tokens (Optional[Iterable[str]]): vocabulary of tokens to
            keep. Defaults to the module-level ``MOLECULE_TOKENS``.
    Returns:
        Tuple[Iterable[str], Iterable[float]]: a tuple of two lists with the
            filtered tokens and attention values. Both lists are empty when
            no token is kept.
    """
    vocabulary = (
        MOLECULE_TOKENS if molecule_tokens is None else set(molecule_tokens)
    )
    # Materialize so that one-shot iterables can be enumerated and indexed.
    tokens = list(tokens)
    smiles_attention = list(smiles_attention)
    to_keep = [
        index for index, token in enumerate(tokens) if token in vocabulary
    ]
    # Index explicitly instead of itemgetter(*to_keep): itemgetter returns a
    # bare item (not a tuple) for a single index and fails for zero indices.
    return (
        [tokens[index] for index in to_keep],
        [smiles_attention[index] for index in to_keep],
    )
def _get_index_and_colors(
    values: Iterable[float],
    tokens: Iterable[str],
    predicate: Callable[[tuple], bool],
    color_mapper: cm.ScalarMappable,
) -> Tuple[List[int], Dict[int, tuple]]:
    """
    Get running indexes and RGBA colors for the tokens selected by a rule.

    Values and tokens are paired positionally. The pairs accepted by the
    predicate are numbered 0, 1, 2, ... in order of appearance, and each
    number is mapped to the color of its value.

    NOTE(review): these running numbers are used by the caller as RDKit
    atom/bond indices. That assumes the selected tokens appear in the same
    order as the molecule's atoms/bonds — confirm for ring closures and
    branches.

    Args:
        values (Iterable[float]): values associated to tokens.
        tokens (Iterable[str]): tokens.
        predicate (Callable[[tuple], bool]): a predicate that acts on a
            ``(value, token)`` tuple.
        color_mapper (cm.ScalarMappable): a color mapper.
    Returns:
        Tuple[List[int], Dict[int, tuple]]: the running indexes of the
            selected pairs and a dict mapping each index to the RGBA color
            (from ``color_mapper.to_rgba``) of its value.
    """
    selected_values = [
        pair[0] for pair in zip(values, tokens) if predicate(pair)
    ]
    colors = {
        index: color_mapper.to_rgba(value)
        for index, value in enumerate(selected_values)
    }
    return list(range(len(selected_values))), colors
def _draw_molecule_svg(
    molecule: Chem.Mol,
    tokens: Iterable[str],
    smiles_attention: Iterable[float],
    color_mapper: cm.ScalarMappable,
) -> str:
    """
    Draw a molecule as SVG with atoms and bonds colored by attention.

    Args:
        molecule (Chem.Mol): the molecule to draw. 2D coordinates are
            computed on it in place.
        tokens (Iterable[str]): SMILES tokens aligned with smiles_attention.
        smiles_attention (Iterable[float]): attention value per token.
        color_mapper (cm.ScalarMappable): maps attention values to RGBA.
    Returns:
        str: the SVG text with newlines replaced by spaces.
    """
    # Explicit import instead of relying on `Chem.rdDepictor` having been
    # loaded as a side effect of another rdkit import.
    from rdkit.Chem import rdDepictor

    # atom tokens drive atom highlights, non-atom tokens drive bond highlights
    highlight_atoms, highlight_atom_colors = _get_index_and_colors(
        smiles_attention, tokens, lambda t: t[1] not in NON_ATOM_TOKENS, color_mapper
    )
    logger.debug("atom colors:\n%s.", highlight_atom_colors)
    highlight_bonds, highlight_bond_colors = _get_index_and_colors(
        smiles_attention, tokens, lambda t: t[1] in NON_ATOM_TOKENS, color_mapper
    )
    logger.debug("bond colors:\n%s.", highlight_bond_colors)

    logger.debug("compute 2D coordinates")
    rdDepictor.Compute2DCoords(molecule)
    logger.debug("get a drawer")
    drawer = rdMolDraw2D.MolDraw2DSVG(SVG_WIDTH, SVG_HEIGHT)
    logger.debug("draw the molecule")
    drawer.DrawMolecule(
        molecule,
        highlightAtoms=highlight_atoms,
        highlightAtomColors=highlight_atom_colors,
        highlightBonds=highlight_bonds,
        highlightBondColors=highlight_bond_colors,
        highlightAtomRadii={index: ATOM_RADII for index in highlight_atoms},
    )
    logger.debug("finish drawing")
    drawer.FinishDrawing()
    logger.debug("drawing to string")
    return drawer.GetDrawingText().replace("\n", " ")


def _draw_colorbar_svg(normalize: mpl.colors.Normalize) -> str:
    """
    Render a vertical colorbar for CMAP under the given normalization as SVG.

    Args:
        normalize (mpl.colors.Normalize): the value normalization to display.
    Returns:
        str: the SVG text with newlines replaced by spaces.
    """
    fig, ax = plt.subplots(figsize=(0.5, 6))
    try:
        mpl.colorbar.ColorbarBase(
            ax,
            cmap=CMAP,
            norm=normalize,
            orientation="vertical",
            extend="both",
            extendrect=True,
        )
        # plain decimal labels instead of LogFormatterSciNotation
        logger.debug("format the colorbar")
        ax.yaxis.set_minor_formatter(ScalarFormatter())
        ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))  # fixes 0.1, 0.20
        logger.debug("colorbar to string")
        file_like = StringIO()
        # save this figure explicitly rather than pyplot's "current" figure
        fig.savefig(file_like, format="svg", bbox_inches="tight")
        return file_like.getvalue().replace("\n", " ")
    finally:
        # Always release the figure: pyplot tracks open figures until they
        # are closed, so skipping this on an error would leak them.
        plt.close(fig)


def smiles_attention_to_svg(
    smiles: str, smiles_attention: Iterable[float]
) -> Tuple[str, str]:
    """
    Generate an svg of the molecule highlighting SMILES attention.

    Args:
        smiles (str): SMILES representing a molecule.
        smiles_attention (Iterable[float]): SMILES attention, one value per
            token position. Positions of tokens outside the SMILES
            vocabulary (e.g. padding) are dropped.
    Returns:
        Tuple[str, str]: drawing, colorbar
            the svg of the molecule highlighting SMILES attention
            and the svg displaying the colorbar
    Raises:
        ValueError: if smiles_attention is empty (from ``min``/``max``).

    NOTE(review): the SMILES is not validated here. For an unparseable SMILES
    ``Chem.MolFromSmiles`` returns None and drawing then fails inside RDKit —
    callers may want to check ``validate_smiles`` first.
    """
    logger.debug("SMILES attention:\n%s.", smiles_attention)
    # Materialize once: the values are scanned several times below, which
    # would silently misbehave for a one-shot iterator.
    smiles_attention = list(smiles_attention)
    logger.debug(
        "SMILES attention range: [%s,%s].",
        min(smiles_attention),
        max(smiles_attention),
    )
    # get the molecule and its tokens
    molecule = Chem.MolFromSmiles(smiles)
    tokens = [
        SMILES_LANGUAGE.index_to_token[token_index]
        for token_index in SMILES_TOKENIZE_FN(smiles)
    ]
    logger.debug("SMILES tokens:%s.", tokens)
    # drop tokens outside the SMILES vocabulary (e.g. padding) and their attention
    tokens, smiles_attention = remove_housekeeping_from_tokens_and_smiles_attention(
        tokens, smiles_attention
    )
    logger.debug(
        "tokens and SMILES attention after removal:\n%s\n%s.",
        tokens,
        smiles_attention,
    )
    logger.debug(
        "SMILES attention range after padding removal: [%s,%s].",
        min(smiles_attention),
        max(smiles_attention),
    )
    # Color scale from the smallest attention up to twice the largest one,
    # capped at 1.0; unknown COLOR_NORMALIZATION values fall back to LogNorm.
    # NOTE(review): LogNorm needs a strictly positive vmin — confirm attention
    # values are always > 0 when the logarithmic scale is used.
    normalize = COLOR_NORMALIZERS.get(COLOR_NORMALIZATION, mpl.colors.LogNorm)(
        vmin=min(smiles_attention), vmax=min(1.0, 2 * max(smiles_attention))
    )
    color_mapper = cm.ScalarMappable(norm=normalize, cmap=CMAP)
    drawing = _draw_molecule_svg(molecule, tokens, smiles_attention, color_mapper)
    logger.debug("draw the colorbar")
    colorbar = _draw_colorbar_svg(normalize)
    return drawing, colorbar