Spaces:
Running
Running
import logging | |
from typing import List | |
import numpy as np | |
import mols2grid | |
import pandas as pd | |
from rdkit import Chem | |
# Module-level logger. A NullHandler is attached so that, following the
# standard-library convention for libraries, warnings from this module are not
# routed to the last-resort stderr handler unless the application configures
# logging itself.
logger: logging.Logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
def draw_grid_predict(
    sequences: List[str],
    properties: np.ndarray,
    property_names: List[str],
    domain: str,
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the prediction.

    Sequences that cannot be converted to SMILES are skipped (with a logged
    warning) together with their row of ``properties``, so the remaining
    properties stay aligned with the molecules that are drawn.

    Args:
        sequences: Sequences for which properties are predicted. Used as SMILES
            directly when ``domain == "Molecules"``; parsed with RDKit's
            ``MolFromFASTA`` when ``domain == "Proteins"``.
        properties: Predicted properties. Array of shape
            (n_samples, n_properties); a 1-D array of shape (n_samples,) is
            treated as a single property column.
        property_names: List of property names, one per column of
            ``properties``.
        domain: Domain of the prediction, either "Molecules" or "Proteins".

    Returns:
        HTML to display.

    Raises:
        ValueError: If ``domain`` is unsupported, or if the number of rows in
            ``properties`` does not match the number of sequences.
    """
    if domain not in ["Molecules", "Proteins"]:
        raise ValueError(f"Unsupported domain {domain}")

    if domain == "Proteins":

        def converter(sequence: str) -> str:
            mol = Chem.MolFromFASTA(sequence)
            # MolFromFASTA signals a parse failure by returning None.
            if mol is None:
                raise ValueError(f"Could not parse FASTA sequence {sequence!r}")
            return Chem.MolToSmiles(mol)

    else:

        def converter(sequence: str) -> str:
            return sequence

    values = np.asarray(properties)
    if values.ndim == 1:
        # Allow a single property to be passed as a flat vector.
        values = values[:, np.newaxis]
    if len(values) != len(sequences):
        raise ValueError(
            f"Got {len(sequences)} sequences but {len(values)} rows of properties"
        )

    smiles: List[str] = []
    kept_rows: List[int] = []
    for row, sequence in enumerate(sequences):
        try:
            converted = converter(sequence)
        except Exception:
            # Best effort: skip sequences that cannot be drawn instead of
            # failing the whole grid.
            logger.warning("Could not draw sequence %s", sequence)
            continue
        smiles.append(converted)
        kept_rows.append(row)

    result = pd.DataFrame({"SMILES": smiles})
    # Drop the property rows of skipped sequences so columns line up with SMILES.
    kept_values = values[np.asarray(kept_rows, dtype=int)]
    for column, name in enumerate(property_names):
        result[name] = kept_values[:, column]

    # Up to three columns; never request a zero-column grid for an empty result.
    n_cols = max(1, min(3, len(result)))
    # Small thumbnails for larger grids, big images when showing at most three.
    size = (140, 200) if len(result) > 3 else (600, 700)
    obj = mols2grid.display(
        result,
        tooltip=list(result.keys()),
        height=1100,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data