"""Gradio demo for sampling molecules with the GT4SD GuacaMol / MOSES generators.

Only the MOSES generators are currently exposed in the UI. The GuacaMol
configuration is retained (see ``ALL_CONFIG_FACTORY``) so it can be re-enabled
by adding its inputs back to the interface.
"""

import logging
import pathlib
from typing import Any, Dict, Tuple

import gradio as gr
import pandas as pd
from gt4sd.algorithms.conditional_generation.guacamol import (
    AaeGenerator,
    GraphGAGenerator,
    GraphMCTSGenerator,
    GuacaMolGenerator,
    MosesGenerator,
    OrganGenerator,
    VaeGenerator,
    SMILESGAGenerator,
    SMILESLSTMHCGenerator,
    SMILESLSTMPPOGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry

from utils import draw_grid_generate

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

TITLE = "GuacaMol & MOSES"

# Every generator this app knows how to configure, keyed as
# "<family> - <generator>". The family prefix selects the MODEL_FACTORY entry.
ALL_CONFIG_FACTORY = {
    "Moses - AaeGenerator": AaeGenerator,
    "Moses - VaeGenerator": VaeGenerator,
    "Moses - OrganGenerator": OrganGenerator,
    "GuacaMol - GraphGAGenerator": GraphGAGenerator,
    "GuacaMol - GraphMCTSGenerator": GraphMCTSGenerator,
    "GuacaMol - SMILESLSTMHCGenerator": SMILESLSTMHCGenerator,
    "GuacaMol - SMILESLSTMPPOGenerator": SMILESLSTMPPOGenerator,
    "GuacaMol - SMILESGAGenerator": SMILESGAGenerator,
}

# Generators exposed in the UI dropdown (Moses only), keyed by bare name.
# Unprefixed names always resolve to DEFAULT_FAMILY.
CONFIG_FACTORY = {
    "AaeGenerator": AaeGenerator,
    "VaeGenerator": VaeGenerator,
    "OrganGenerator": OrganGenerator,
}
MODEL_FACTORY = {"Moses": MosesGenerator, "GuacaMol": GuacaMolGenerator}

# Family assumed for the bare names in CONFIG_FACTORY.
DEFAULT_FAMILY = "Moses"


def _resolve_algorithm(algorithm_version: str) -> Tuple[str, Any]:
    """Map a dropdown value to its ``(family, configuration_class)`` pair.

    Bare names (``CONFIG_FACTORY`` keys) belong to ``DEFAULT_FAMILY``;
    prefixed names (``ALL_CONFIG_FACTORY`` keys) carry their family before
    the ``" - "`` separator.

    Raises:
        KeyError: if ``algorithm_version`` is not a known generator.
    """
    if algorithm_version in CONFIG_FACTORY:
        return DEFAULT_FAMILY, CONFIG_FACTORY[algorithm_version]
    if algorithm_version in ALL_CONFIG_FACTORY:
        family = algorithm_version.split(" - ", 1)[0]
        return family, ALL_CONFIG_FACTORY[algorithm_version]
    raise KeyError(f"Unknown algorithm version {algorithm_version!r}")


def _build_config_kwargs(
    family: str,
    algorithm_version: str,
    length: int,
    number_of_samples: int,
    population_size: int,
    random_start: bool,
    patience: int,
    generations: int,
) -> Dict[str, Any]:
    """Assemble the configuration keyword arguments for one generator.

    Raises:
        ValueError: if ``family`` is neither "Moses" nor "GuacaMol".
    """
    if family == "Moses":
        return {"n_samples": number_of_samples, "max_len": length}

    if family == "GuacaMol":
        kwargs: Dict[str, Any] = {
            "population_size": population_size,
            "random_start": random_start,
            "patience": patience,
            "generations": generations,
        }
        # Individual GuacaMol generators accept different subsets of options.
        if "MCTS" in algorithm_version:
            kwargs.pop("random_start")
        elif "LSTMHC" in algorithm_version:
            kwargs["max_len"] = length
            kwargs.pop("population_size")
            kwargs.pop("patience")
            kwargs.pop("generations")
        elif "LSTMPPO" in algorithm_version:
            kwargs = {}
        return kwargs

    raise ValueError(f"Unknown family {family}")


def run_inference(
    algorithm_version: str,
    length: int,
    number_of_samples: int,
    population_size: int = 100,
    random_start: bool = False,
    patience: int = 4,
    generations: int = 2,
):
    """Sample molecules with the selected generator and render them as HTML.

    Args:
        algorithm_version: a key of ``CONFIG_FACTORY`` (bare Moses generator
            name) or of ``ALL_CONFIG_FACTORY`` ("<family> - <generator>").
        length: maximum sequence length (used by Moses and SMILES LSTM-HC).
        number_of_samples: how many molecules to draw from the model.
        population_size: GuacaMol option; ignored for the Moses family.
        random_start: GuacaMol option; ignored for the Moses family.
        patience: GuacaMol option; ignored for the Moses family.
        generations: GuacaMol option; ignored for the Moses family.
            The four GuacaMol defaults match the (currently hidden) UI sliders.

    Returns:
        Whatever ``draw_grid_generate`` returns for the sampled molecules
        (rendered into a ``gr.HTML`` output by the interface).

    Raises:
        KeyError: for an unknown ``algorithm_version``.
    """
    # Gradio sliders may deliver floats; the generators expect integers.
    length = int(length)
    number_of_samples = int(number_of_samples)

    family, config_class = _resolve_algorithm(algorithm_version)
    model_class = MODEL_FACTORY[family]
    kwargs = _build_config_kwargs(
        family,
        algorithm_version,
        length,
        number_of_samples,
        population_size,
        random_start,
        patience,
        generations,
    )

    config = config_class(**kwargs)
    model = model_class(configuration=config, target={})
    samples = list(model.sample(number_of_samples))
    return draw_grid_generate(seeds=[], samples=samples, n_cols=5)


def main() -> None:
    """Build the Gradio interface from the model-card metadata and launch it."""
    all_algos = ApplicationsRegistry.list_available()
    # Only the Moses generators are offered for now. GuacaMol entries would be
    # listed as "GuacaMol - <algorithm_application>" (see ALL_CONFIG_FACTORY)
    # once their inputs are added back to the interface.
    algos = [
        x["algorithm_application"]
        for x in all_algos
        if "Moses" in x["algorithm_name"]
    ]

    # Load metadata
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), header=None
    ).fillna("")
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="MOSES",
        inputs=[
            gr.Dropdown(algos, label="Algorithm version", value="AaeGenerator"),
            gr.Slider(
                minimum=5, maximum=500, value=100, label="Sequence length", step=1
            ),
            # GuacaMol-only inputs (population size, random start, patience,
            # generations) are omitted; run_inference supplies defaults.
            gr.Slider(
                minimum=1, maximum=50, value=5, label="Number of samples", step=1
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)


if __name__ == "__main__":
    main()