File size: 1,623 Bytes
78e0383
14da265
 
78e0383
 
 
 
 
 
 
 
14da265
 
78e0383
 
14da265
78e0383
 
14da265
 
 
 
78e0383
 
 
 
 
 
 
 
 
14da265
78e0383
14da265
78e0383
14da265
 
 
 
 
 
 
78e0383
14da265
 
 
 
 
78e0383
14da265
78e0383
9db0a68
78e0383
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import logging
from typing import List
import numpy as np
import mols2grid
import pandas as pd
from rdkit import Chem

# Module-level logger named after this module. Attaching a NullHandler means
# that, when the application configures no logging, records emitted here are
# not printed through logging's last-resort handler.
logger: logging.Logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())


def _sequence_to_smiles(sequence: str, domain: str) -> str:
    """Convert one input sequence into the SMILES string rendered by mols2grid.

    Args:
        sequence: A SMILES string for "Molecules", or a sequence parsed with
            ``Chem.MolFromFASTA`` for "Proteins".
        domain: Either "Molecules" or "Proteins".

    Returns:
        The SMILES representation of ``sequence``.

    Raises:
        ValueError: If RDKit cannot parse a protein sequence.
    """
    if domain != "Proteins":
        # Molecule inputs are already SMILES and are passed through unchanged.
        return sequence
    mol = Chem.MolFromFASTA(sequence)
    # RDKit's MolFrom* parsers report failure by returning None, not raising.
    if mol is None:
        raise ValueError(f"RDKit could not parse FASTA sequence {sequence!r}")
    return Chem.MolToSmiles(mol)


def draw_grid_predict(
    sequences: List[str], properties: np.ndarray, property_names: List[str], domain: str
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the prediction.

    Sequences that cannot be converted to SMILES are logged and dropped,
    together with their row of ``properties``, so every remaining property
    value stays aligned with its molecule.

    Args:
        sequences: Sequences for which properties are predicted (SMILES for
            "Molecules", sequences parsed by ``Chem.MolFromFASTA`` for "Proteins").
        properties: Predicted properties. Array of shape (n_samples, n_properties);
            row ``k`` belongs to ``sequences[k]``.
        property_names: Property names, one per column of ``properties``.
        domain: Domain of the prediction, exactly "Molecules" or "Proteins".

    Returns:
        HTML to display (the ``data`` attribute of the object returned by
        ``mols2grid.display``).

    Raises:
        ValueError: If ``domain`` is unsupported, or if the number of
            ``properties`` rows differs from the number of sequences.
    """
    if domain not in ["Molecules", "Proteins"]:
        raise ValueError(f"Unsupported domain {domain}")

    # Materialize once so len() works and the input is iterated only once.
    sequences = list(sequences)
    properties = np.asarray(properties)
    if len(properties) != len(sequences):
        raise ValueError(
            f"Got {len(sequences)} sequences but {len(properties)} rows of properties"
        )

    smiles = []
    kept_rows = []
    for row, sequence in enumerate(sequences):
        try:
            converted = _sequence_to_smiles(sequence, domain)
        except Exception:
            # Best effort: skip undrawable inputs instead of failing the grid.
            logger.warning("Could not draw sequence %s", sequence)
            continue
        smiles.append(converted)
        kept_rows.append(row)

    # Keep only the property rows of the sequences that were converted, so
    # each column below has the same length as "SMILES".
    kept_properties = properties[np.asarray(kept_rows, dtype=int)]

    result = pd.DataFrame({"SMILES": smiles})
    for i, name in enumerate(property_names):
        result[name] = kept_properties[:, i]
    n_cols = min(3, len(result))
    # Up to three results are drawn large; bigger grids use smaller tiles.
    size = (140, 200) if len(result) > 3 else (600, 700)
    obj = mols2grid.display(
        result,
        tooltip=list(result.keys()),
        subset=["img"] + list(result.keys()),
        height=1100,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data