File size: 1,390 Bytes
c703bc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import logging

import mols2grid
import pandas as pd
from rdkit import Chem
from terminator.selfies import decoder

# Module logger. A NullHandler is attached, as is customary for library code,
# so that where (and whether) records are emitted is left to the application.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())


def draw_grid_predict(prediction: str, target: str, domain: str) -> str:
    """
    Uses mols2grid to draw a HTML grid for the prediction

    Only the part of ``target`` after its last ``|`` is drawn. It is converted
    to SMILES (via the SELFIES ``decoder`` for ``"Molecules"``, via RDKit's
    FASTA parser for ``"Proteins"``) and shown as a single cell named
    ``"Target"``. If conversion fails, a warning is logged and the raw,
    unconverted sequence is used instead.

    Every ``<name>value`` segment in ``prediction`` becomes an extra column
    (shown in the tooltip) with the text ``"Name = value"``. Segments that
    lack a closing ``>`` are skipped with a warning.

    Args:
        prediction: Predicted sequence containing ``<property>value``
            segments.
        target: Target molecule; only the text after the last ``|`` is used.
        domain: Domain of the prediction, ``"Molecules"`` or ``"Proteins"``
            (case-sensitive).

    Returns:
        HTML to display

    Raises:
        ValueError: If ``domain`` is not ``"Molecules"`` or ``"Proteins"``.
    """
    if domain not in ("Molecules", "Proteins"):
        raise ValueError(f"Unsupported domain {domain}")

    def _fasta_to_smiles(fasta: str) -> str:
        # MolFromFASTA returns None on failure instead of raising; fail
        # explicitly rather than handing None to MolToSmiles.
        mol = Chem.MolFromFASTA(fasta)
        if mol is None:
            raise ValueError(f"Invalid FASTA sequence {fasta}")
        return Chem.MolToSmiles(mol)

    converter = decoder if domain == "Molecules" else _fasta_to_smiles

    seq = target.split("|")[-1]
    try:
        seq = converter(seq)
    except Exception:
        # Best effort: keep the raw sequence so a grid is still produced.
        logger.warning("Could not draw sequence %s", seq)

    result = {"SMILES": [seq], "Name": ["Target"]}
    # Add properties: each "<name>value" segment becomes one column. Text
    # before the first "<" is not a property and is ignored.
    for segment in prediction.split("<")[1:]:
        parts = segment.split(">")
        if len(parts) < 2:
            # No closing ">" means there is no name/value pair to show;
            # skip it instead of crashing with an IndexError.
            logger.warning("Ignoring malformed property segment %r", segment)
            continue
        name, value = parts[0], parts[1]
        result[name] = [f"{name.capitalize()} = {value}"]

    result_df = pd.DataFrame(result)
    obj = mols2grid.display(
        result_df,
        tooltip=list(result.keys()),
        height=900,
        n_cols=1,
        name="Results",
        size=(600, 700),
    )
    return obj.data