|
import logging |
|
import pathlib |
|
|
|
import gradio as gr |
|
import pandas as pd |
|
from gt4sd.algorithms.conditional_generation.guacamol import ( |
|
AaeGenerator, |
|
GraphGAGenerator, |
|
GraphMCTSGenerator, |
|
GuacaMolGenerator, |
|
MosesGenerator, |
|
OrganGenerator, |
|
VaeGenerator, |
|
SMILESGAGenerator, |
|
SMILESLSTMHCGenerator, |
|
SMILESLSTMPPOGenerator, |
|
) |
|
from gt4sd.algorithms.registry import ApplicationsRegistry |
|
|
|
from utils import draw_grid_generate |
|
|
|
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

TITLE = "GuacaMol & MOSES"

# Algorithm version (as received by run_inference) -> configuration class.
# Only the Moses generators are listed because run_inference always uses the
# "Moses" model family. Defining the mapping once avoids the earlier pattern
# of building a larger dict and immediately overwriting it.
CONFIG_FACTORY = {
    "AaeGenerator": AaeGenerator,
    "VaeGenerator": VaeGenerator,
    "OrganGenerator": OrganGenerator,
}

# Model family name -> generator class instantiated with a configuration.
MODEL_FACTORY = {"Moses": MosesGenerator, "GuacaMol": GuacaMolGenerator}
|
|
|
|
|
def run_inference(
    algorithm_version: str,
    length: int,
    number_of_samples: int,
):
    """Sample molecules with a Moses generator and render them as a grid.

    Args:
        algorithm_version: Key into ``CONFIG_FACTORY`` selecting the
            configuration class to instantiate.
        length: Forwarded to the configuration as ``max_len``.
        number_of_samples: Forwarded to the configuration as ``n_samples``
            and used as the number of items requested from ``model.sample``.

    Returns:
        The value produced by
        ``draw_grid_generate(seeds=[], samples=samples, n_cols=5)``.

    Raises:
        KeyError: If ``algorithm_version`` is not a key of ``CONFIG_FACTORY``.
    """
    config_class = CONFIG_FACTORY[algorithm_version]

    # Only the Moses family is supported by this function: its parameters
    # carry exactly what a Moses configuration needs (sample count and
    # maximum length), so the family is fixed rather than derived.
    model_class = MODEL_FACTORY["Moses"]

    config = config_class(n_samples=number_of_samples, max_len=length)
    model = model_class(configuration=config, target={})

    # Materialise the samples before handing them to the drawing helper.
    samples = list(model.sample(number_of_samples))

    return draw_grid_generate(seeds=[], samples=samples, n_cols=5)
|
|
|
|
|
if __name__ == "__main__":
    # Offer only the registry entries whose algorithm name mentions "Moses".
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        algo["algorithm_application"]
        for algo in all_algos
        if "Moses" in algo["algorithm_name"]
    ]

    # Examples and markdown texts are loaded from the "model_cards" folder
    # located next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), header=None
    ).fillna("")

    # Read with an explicit encoding so the result does not depend on the
    # platform's default locale.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="MOSES",
        # The order of the inputs must match the positional parameters of
        # run_inference: algorithm_version, length, number_of_samples.
        inputs=[
            gr.Dropdown(algos, label="Algorithm version", value="AaeGenerator"),
            gr.Slider(
                minimum=5, maximum=500, value=100, label="Sequence length", step=1
            ),
            gr.Slider(
                minimum=1, maximum=50, value=5, label="Number of samples", step=1
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)
|
|