File size: 4,692 Bytes
78e0383
 
 
 
14da265
78e0383
14da265
 
 
78e0383
 
 
 
a0a4d74
1f4fe6e
 
a0a4d74
 
 
 
 
 
 
 
 
 
 
 
1f4fe6e
a0a4d74
 
14da265
 
a0a4d74
78e0383
14da265
 
 
 
 
 
 
 
1f4fe6e
 
 
14da265
 
 
78e0383
f9d0119
1f4fe6e
7cb72bc
14da265
5f6f2d9
14da265
5f6f2d9
14da265
 
 
 
 
 
 
 
 
 
 
78e0383
14da265
 
 
 
 
78e0383
 
 
 
14da265
 
 
 
 
78e0383
a0a4d74
 
1f4fe6e
 
 
a0a4d74
78e0383
 
 
14da265
5f6f2d9
14da265
 
 
 
 
 
78e0383
14da265
78e0383
14da265
78e0383
 
 
14da265
 
78e0383
2355d36
78e0383
14da265
 
 
78e0383
14da265
 
 
78e0383
 
 
 
 
14da265
78e0383
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import logging
import pathlib

import gradio as gr
import numpy as np
import pandas as pd
from gt4sd.properties.molecules import MOLECULE_PROPERTY_PREDICTOR_FACTORY

from utils import draw_grid_predict

# Module-level logger. Only a do-nothing handler is attached here, so this
# module produces no log output of its own unless the application configures
# logging elsewhere.
_null_handler = logging.NullHandler()
logger = logging.getLogger(__name__)
logger.addHandler(_null_handler)

# MolFormer predictor keys mapped to the version names offered for each one.
# `main` forwards the chosen name as `algorithm_version`; the script section
# expands every (key, version) pair into one dropdown entry.
MOLFORMER_VERSIONS = dict(
    molformer_classification=["bace", "bbbp", "hiv"],
    molformer_regression=[
        "alpha", "cv", "g298", "gap", "h298", "homo",
        "lipo", "lumo", "mu", "r2", "u0",
    ],
    molformer_multitask_classification=["clintox", "sider", "tox21"],
)

# Factory keys that are taken out of the plain property list built in the
# script section. The bare MolFormer keys are included because they are
# re-added there as one entry per version.
REMOVE = [
    "docking",
    "docking_tdc",
    "molecule_one",
    "askcos",
    "plogp",
    "similarity_seed",
    "activity_against_target",
    "organtox",
    *MOLFORMER_VERSIONS,
]

# Output names of the multi-output predictors, stored as one comma-separated
# string per property. `main` splits the string on "," to label each predicted
# value, and also runs every property listed here with algorithm version "v0".
MODEL_PROP_DESCRIPTION = {
    "Tox21": "NR-AR, NR-AR-LBD, NR-AhR, NR-Aromatase, NR-ER, NR-ER-LBD, NR-PPAR-gamma, SR-ARE, SR-ATAD5, SR-HSE, SR-MMP, SR-p53",
    # Label fixed: "Bening & malignant" was a typo for "Benign & malignant".
    "Sider": "Hepatobiliary disorders,Metabolism and nutrition disorders,Product issues,Eye disorders,Investigations,Musculoskeletal disorders,Gastrointestinal disorders,Social circumstances,Immune system disorders,Reproductive system and breast disorders,Benign & malignant,General disorders,Endocrine disorders,Surgical & medical procedures,Vascular disorders,Blood & lymphatic disorders,Skin & subcutaneous disorders,Congenital & genetic disorders,Infections,Respiratory & thoracic disorders,Psychiatric disorders,Renal & urinary disorders,Pregnancy conditions,Ear disorders,Cardiac disorders,Nervous system disorders,Injury & procedural complications",
    "Clintox": "FDA approval, Clinical trial failure",
}


def _read_smiles_input(smiles, smiles_file):
    """Return the list of SMILES strings selected by the two input widgets.

    Args:
        smiles: Text of the single-SMILES box; ``None`` or a blank string
            counts as "not provided".
        smiles_file: Uploaded file, or ``None``. Either a path (``str`` or
            ``pathlib`` path) or an object exposing the path as ``.name``.

    Returns:
        A list with the single stripped SMILES, or the first column of the
        tab-separated file.

    Raises:
        ValueError: If both inputs are provided, or neither is.
    """
    has_text = smiles is not None and smiles.strip() != ""
    has_file = smiles_file is not None

    if has_text and has_file:
        raise ValueError("Pass either smiles or smiles_file, not both.")
    if has_text:
        return [smiles.strip()]
    if has_file:
        # A pathlib path also has a `.name`, but that is only the basename,
        # so path-like inputs must be recognised before falling back to `.name`.
        if isinstance(smiles_file, (str, pathlib.PurePath)):
            path = smiles_file
        else:
            path = smiles_file.name
        return pd.read_csv(path, header=None, sep="\t")[0].tolist()
    raise ValueError("Pass either smiles or smiles_file.")


def main(property: str, smiles: str, smiles_file: str):
    """Predict one property for the given molecule(s) and render the result.

    Args:
        property: Entry chosen in the dropdown, e.g. ``"Qed"`` or, for
            MolFormer, ``"Molformer_regression (lipo)"`` where the text in
            parentheses is the algorithm version.
        smiles: A single SMILES string, or ``""`` when a file is used.
        smiles_file: Uploaded tab-separated file whose first column holds
            SMILES, or ``None``. A path or an object with a ``.name`` path
            attribute is accepted.

    Returns:
        Whatever ``draw_grid_predict`` returns for the molecules, their
        predictions rounded to two decimals, and the property names.

    Raises:
        ValueError: If both or neither of ``smiles`` / ``smiles_file`` are
            given, or a MolFormer property is requested without a version.
        KeyError: If ``property`` is not a key of the predictor factory.
    """
    # Validate the inputs before the model is instantiated, so an invalid
    # request fails before any predictor is built.
    smiles_list = _read_smiles_input(smiles, smiles_file)

    version = None
    if "Molformer" in property:
        # "Molformer_x (ver)" -> version "ver", property "Molformer_x"
        version = property.split(" ")[-1].split("(")[-1].split(")")[0]
        property = property.split(" ")[0]

    algo, config = MOLECULE_PROPERTY_PREDICTOR_FACTORY[property.lower()]
    kwargs = {"algorithm_version": "v0"} if property in MODEL_PROP_DESCRIPTION else {}
    if property.lower() in MOLFORMER_VERSIONS:
        if version is None:
            raise ValueError(
                f"No algorithm version given for MolFormer property {property!r}."
            )
        kwargs["algorithm_version"] = version

    model = algo(config(**kwargs))
    props = np.array([model(molecule) for molecule in smiles_list]).round(2)

    # One value per molecule yields a 1D array; make it (n_molecules, 1)
    if props.ndim == 1:
        props = np.expand_dims(props, -1)

    if property in MODEL_PROP_DESCRIPTION:
        # Some descriptions separate names with ", "; strip so labels carry
        # no leading blank.
        property_names = [
            name.strip() for name in MODEL_PROP_DESCRIPTION[property].split(",")
        ]
    else:
        property_names = [property]

    return draw_grid_predict(
        smiles_list, props, property_names=property_names, domain="Molecules"
    )


if __name__ == "__main__":
    # Preparation: take all registered predictors in reverse registration
    # order, drop the excluded ones and capitalize the names for display.
    # Filtering by membership replaces the per-item index rebuild and does
    # not fail if an excluded name is absent from the factory.
    excluded = set(REMOVE)
    properties = [
        name.capitalize()
        for name in list(MOLECULE_PROPERTY_PREDICTOR_FACTORY.keys())[::-1]
        if name not in excluded
    ]

    # MolFormer options: one dropdown entry per (predictor, version) pair,
    # in the "<Name> (<version>)" format that `main` parses.
    for key, versions in MOLFORMER_VERSIONS.items():
        properties.extend(f"{key.capitalize()} ({version})" for version in versions)

    # Load metadata
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    examples = [
        ["Qed", "", metadata_root.joinpath("examples.smi")],
        [
            "Esol",
            "CN1CCN(CCCOc2ccc(N3C(=O)C(=Cc4ccc(Oc5ccc([N+](=O)[O-])cc5)cc4)SC3=S)cc2)CC1",
            None,
        ],
    ]

    # Explicit encoding so the markdown reads the same regardless of the
    # platform's default locale.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=main,
        title="Molecular properties",
        inputs=[
            gr.Dropdown(properties, label="Property", value="Scscore"),
            gr.Textbox(
                label="Single SMILES",
                placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1",
                lines=1,
            ),
            gr.File(
                file_types=[".smi"],
                label="Multiple SMILES (tab-separated, `.smi` file)",
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True)