Spaces:
Running
Running
File size: 1,623 Bytes
78e0383 14da265 78e0383 14da265 78e0383 14da265 78e0383 14da265 78e0383 14da265 78e0383 14da265 78e0383 14da265 78e0383 14da265 78e0383 14da265 78e0383 9db0a68 78e0383 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import logging
from typing import List
import numpy as np
import mols2grid
import pandas as pd
from rdkit import Chem
# Module-level logger. The NullHandler keeps this module silent unless the
# host application configures logging itself.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
def draw_grid_predict(
    sequences: List[str],
    properties: np.ndarray,
    property_names: List[str],
    domain: str,
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the prediction.

    Sequences that cannot be converted for display are skipped with a logged
    warning, and their rows of ``properties`` are dropped as well, so every
    remaining grid entry stays paired with its own predicted values.

    Args:
        sequences: Sequences for which properties are predicted.
        properties: Predicted properties. Array of shape
            (n_samples, n_properties); row i belongs to ``sequences[i]``.
        property_names: List of property names, one per column of
            ``properties``.
        domain: Domain of the prediction, either "Molecules" or "Proteins".

    Returns:
        HTML to display.

    Raises:
        ValueError: If ``domain`` is not "Molecules" or "Proteins", or if the
            number of property rows differs from the number of sequences.
    """
    if domain not in ["Molecules", "Proteins"]:
        raise ValueError(f"Unsupported domain {domain}")

    if domain == "Proteins":

        def converter(x: str) -> str:
            # Parse the FASTA string with RDKit and re-serialize it as SMILES
            # so mols2grid can render it. An unparsable input makes
            # MolFromFASTA return None, and MolToSmiles then raises, which
            # the loop below catches.
            return Chem.MolToSmiles(Chem.MolFromFASTA(x))

    else:

        def converter(x: str) -> str:
            # Molecules are already given as SMILES.
            return x

    properties = np.asarray(properties)
    if len(properties) != len(sequences):
        raise ValueError(
            f"Got {len(properties)} property rows for {len(sequences)} sequences"
        )

    smiles = []
    kept_rows = []  # indices of sequences that converted successfully
    for idx, sequence in enumerate(sequences):
        try:
            converted = converter(sequence)
        except Exception:
            # Best-effort: one bad sequence should not prevent drawing the
            # rest. Log the raw input, since no converted value exists here.
            logger.warning("Could not draw sequence %s", sequence)
            continue
        smiles.append(converted)
        kept_rows.append(idx)

    # Keep only the property rows whose sequence made it into the grid, so
    # every column has the same length as the SMILES column.
    kept_properties = properties[np.asarray(kept_rows, dtype=int)]

    result = pd.DataFrame({"SMILES": smiles})
    for i, name in enumerate(property_names):
        result[name] = kept_properties[:, i]

    n_cols = min(3, len(result))
    # A handful of results is shown with large tiles; more results use
    # compact tiles so the grid stays readable.
    size = (140, 200) if len(result) > 3 else (600, 700)
    obj = mols2grid.display(
        result,
        tooltip=list(result.keys()),
        subset=["img"] + list(result.keys()),
        height=1100,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data
|