protein_properties / utils.py
jannisborn's picture
update
453d7ec unverified
raw
history blame contribute delete
No virus
1.7 kB
import logging
from typing import List
import numpy as np
import mols2grid
import pandas as pd
from rdkit import Chem
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Attach a NullHandler (a handler that discards records) to this logger.
logger.addHandler(logging.NullHandler())
def draw_grid_predict(
    sequences: List[str],
    properties: np.array,
    property_names: List[str],
    domain: str,
    n_cols: int = -1,
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the prediction.

    Sequences that cannot be converted to SMILES are skipped with a warning.
    Their rows are dropped from ``properties`` as well, so that every row of
    the grid still pairs a structure with its own predicted values.

    Args:
        sequences: Sequences for which properties are predicted. Used as-is
            as SMILES for the "Molecules" domain; converted from FASTA to
            SMILES via RDKit for the "Proteins" domain.
        properties: Predicted properties. Array of shape (n_samples, n_properties).
        property_names: List of property names, one per column of ``properties``.
        domain: Domain of the prediction ("Molecules" or "Proteins").
        n_cols: Number of columns of the grid. The default of -1 selects
            ``min(3, number of drawn rows)``.

    Returns:
        HTML to display

    Raises:
        ValueError: If ``domain`` is neither "Molecules" nor "Proteins".
    """
    if domain not in ["Molecules", "Proteins"]:
        raise ValueError(f"Unsupported domain {domain}")

    is_protein = domain == "Proteins"

    smiles = []
    kept = []  # indices of the sequences that were converted successfully
    for idx, sequence in enumerate(sequences):
        try:
            if is_protein:
                converted = Chem.MolToSmiles(Chem.MolFromFASTA(sequence))
            else:
                converted = sequence
        except Exception:
            # Best effort: one undrawable sequence must not abort the whole
            # grid. Log the offending *input* (the converted value does not
            # exist when the conversion itself failed).
            logger.warning(f"Could not draw sequence {sequence}")
            continue
        smiles.append(converted)
        kept.append(idx)

    result = pd.DataFrame({"SMILES": smiles})

    # Only filter when something was actually skipped, so the common path
    # hands the caller's objects to pandas untouched.
    if len(kept) != len(sequences):
        sequences = [sequences[i] for i in kept]
        properties = np.asarray(properties)[kept]

    if is_protein:
        result["Seqs"] = sequences
    for i, name in enumerate(property_names):
        result[name] = properties[:, i]

    if n_cols == -1:
        n_cols = min(3, len(result))
    # Few entries get large tiles, many entries get small ones.
    size = (250, 200) if len(result) > 3 else (600, 700)

    obj = mols2grid.display(
        result,
        tooltip=list(result.keys()),
        height=1100,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data