Spaces:
GT4SD
/
Runtime error

File size: 4,565 Bytes
34ae1d8
 
 
 
 
8ae3405
 
 
 
 
 
 
 
 
 
 
 
34ae1d8
8ae3405
34ae1d8
 
 
 
 
8ae3405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34ae1d8
 
 
 
8ae3405
 
 
 
 
34ae1d8
 
8ae3405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34ae1d8
 
8ae3405
34ae1d8
 
 
 
 
 
8ae3405
 
 
 
 
 
 
 
 
 
 
34ae1d8
8ae3405
 
34ae1d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ae3405
34ae1d8
8ae3405
 
 
34ae1d8
8ae3405
 
 
 
 
 
34ae1d8
8ae3405
34ae1d8
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import logging
import pathlib

import gradio as gr
import pandas as pd
from gt4sd.algorithms.conditional_generation.guacamol import (
    AaeGenerator,
    GraphGAGenerator,
    GraphMCTSGenerator,
    GuacaMolGenerator,
    MosesGenerator,
    OrganGenerator,
    VaeGenerator,
    SMILESGAGenerator,
    SMILESLSTMHCGenerator,
    SMILESLSTMPPOGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry

from utils import draw_grid_generate

# Module-level logger. A NullHandler is attached so this module emits nothing
# by itself; records still propagate to handlers configured by the application.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Title string for the combined demo.
# NOTE(review): not referenced in the visible part of this file; the interface
# built under ``__main__`` hard-codes the title "MOSES" instead.
TITLE = "GuacaMol & MOSES"

# Mapping from "<family> - <application>" display names to the GT4SD
# configuration classes of both families.
# NOTE(review): this assignment is dead, because CONFIG_FACTORY is rebound to
# the Moses-only mapping right below before anything reads it.
CONFIG_FACTORY = {
    "Moses - AaeGenerator": AaeGenerator,
    "Moses - VaeGenerator": VaeGenerator,
    "Moses - OrganGenerator": OrganGenerator,
    "GuacaMol - GraphGAGenerator": GraphGAGenerator,
    "GuacaMol - GraphMCTSGenerator": GraphMCTSGenerator,
    "GuacaMol - SMILESLSTMHCGenerator": SMILESLSTMHCGenerator,
    "GuacaMol - SMILESLSTMPPOGenerator": SMILESLSTMPPOGenerator,
    "GuacaMol - SMILESGAGenerator": SMILESGAGenerator,
}
# Effective mapping: rebinding CONFIG_FACTORY restricts the app to the three
# Moses configurations, keyed by the bare class name (no family prefix).
CONFIG_FACTORY = {
    "AaeGenerator": AaeGenerator,
    "VaeGenerator": VaeGenerator,
    "OrganGenerator": OrganGenerator,
}
# Maps a family name to the generator class that is instantiated around a
# configuration object in ``run_inference``.
MODEL_FACTORY = {"Moses": MosesGenerator, "GuacaMol": GuacaMolGenerator}


def run_inference(
    algorithm_version: str,
    length: int,
    number_of_samples: int,
    population_size: int = 100,
    random_start: bool = False,
    patience: int = 4,
    generations: int = 2,
):
    """Sample molecules with the selected algorithm and render them as a grid.

    Args:
        algorithm_version: Key into ``CONFIG_FACTORY``. A name of the form
            ``"<family> - <application>"`` selects that family; a bare name
            (the only kind currently registered) is treated as Moses.
        length: Maximum sequence length, forwarded as ``max_len``.
        number_of_samples: Number of molecules to sample; for Moses it is
            also forwarded to the configuration as ``n_samples``.
        population_size: GuacaMol only; forwarded to the configuration.
        random_start: GuacaMol only; forwarded to the configuration (dropped
            for MCTS-based versions).
        patience: GuacaMol only; forwarded to the configuration.
        generations: GuacaMol only; forwarded to the configuration.

    Returns:
        The result of ``draw_grid_generate`` for the sampled molecules, laid
        out in five columns.

    Raises:
        KeyError: If ``algorithm_version`` is not a key of ``CONFIG_FACTORY``.
        ValueError: If the family derived from ``algorithm_version`` is unknown.
    """
    config_class = CONFIG_FACTORY[algorithm_version]

    # Bare names carry no family prefix; they belong to Moses, which is the
    # only family exposed at the moment.
    family, separator, _ = algorithm_version.partition(" - ")
    if not separator:
        family = "Moses"
    if family not in MODEL_FACTORY:
        raise ValueError(f"Unknown family {family}")
    model_class = MODEL_FACTORY[family]

    if family == "Moses":
        kwargs = {"n_samples": number_of_samples, "max_len": length}
    elif family == "GuacaMol":
        # The GuacaMol settings used to be read from names that were never
        # defined in this scope; they are now regular optional parameters.
        kwargs = {
            "population_size": population_size,
            "random_start": random_start,
            "patience": patience,
            "generations": generations,
        }
        # Each GuacaMol configuration accepts a different subset of settings.
        if "MCTS" in algorithm_version:
            kwargs.pop("random_start")
        if "LSTMHC" in algorithm_version:
            kwargs["max_len"] = length
            kwargs.pop("population_size")
            kwargs.pop("patience")
            kwargs.pop("generations")
        if "LSTMPPO" in algorithm_version:
            kwargs = {}
    else:
        raise ValueError(f"Unknown family {family}")

    config = config_class(**kwargs)

    model = model_class(configuration=config, target={})
    samples = list(model.sample(number_of_samples))

    return draw_grid_generate(seeds=[], samples=samples, n_cols=5)


if __name__ == "__main__":

    # Preparation: ask the GT4SD registry for every available algorithm and
    # keep only the Moses applications, since the app exposes Moses only.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        algo["algorithm_application"]
        for algo in all_algos
        if "Moses" in algo["algorithm_name"]
    ]

    # Load metadata shipped next to this script.
    metadata_root = pathlib.Path(__file__).parent / "model_cards"

    # The examples table has no header row; empty cells become "" instead of NaN.
    examples = pd.read_csv(metadata_root / "examples.csv", header=None).fillna("")

    # Read the markdown with an explicit encoding so the result does not
    # depend on the locale of the host.
    article = (metadata_root / "article.md").read_text(encoding="utf-8")
    description = (metadata_root / "description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=run_inference,
        title="MOSES",
        inputs=[
            gr.Dropdown(algos, label="Algorithm version", value="AaeGenerator"),
            gr.Slider(
                minimum=5, maximum=500, value=100, label="Sequence length", step=1
            ),
            # GuacaMol-only controls, disabled while the app exposes Moses only:
            # gr.Slider(
            #     minimum=5, maximum=500, value=100, label="Population size", step=1
            # ),
            # gr.Radio(choices=[True, False], label="Random start", value=False),
            # gr.Slider(minimum=1, maximum=10, value=4, label="Patience"),
            # gr.Slider(minimum=1, maximum=10, value=2, label="Generations"),
            gr.Slider(
                minimum=1, maximum=50, value=5, label="Number of samples", step=1
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)