File size: 1,868 Bytes
5da68a0 b68abc1 5da68a0 b68abc1 5da68a0 b68abc1 5da68a0 b68abc1 5da68a0 d2d894e 5da68a0 b68abc1 d2d894e b68abc1 5da68a0 b68abc1 d2d894e 5da68a0 b68abc1 5da68a0 b68abc1 5da68a0 d2d894e 5da68a0 b68abc1 5da68a0 b68abc1 5da68a0 b68abc1 5da68a0 b68abc1 5da68a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import logging
import pathlib
import pickle
import gradio as gr
from typing import Dict, Any
import pandas as pd
from gt4sd.algorithms.generation.diffusion import (
DiffusersGenerationAlgorithm,
GeoDiffGenerator,
)
from utils import draw_grid_generate
from rdkit import Chem
# Per-module logger. A NullHandler is attached so that, when the application
# has not configured logging, records emitted here are dropped quietly instead
# of falling through to Python's last-resort stderr handler.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())
def run_inference(prompt_file: str, prompt_id: int, number_of_samples: int):
    """Sample molecules with GeoDiff from a pickled prompt and render them as a grid.

    Args:
        prompt_file: The uploaded ``.pkl`` prompt file. Accepted either as a
            filepath (``str`` or ``pathlib.PurePath``) or as a file wrapper
            that exposes its path via ``.name`` (as some Gradio versions pass
            from ``gr.File``).
        prompt_id: Key of the prompt to use when the pickle holds a dict whose
            keys are all ``int``; otherwise the whole pickle is used as the
            prompt and this value is ignored.
        number_of_samples: Number of samples to draw from the model.

    Returns:
        The output of ``draw_grid_generate`` for the SMILES strings of the
        sampled molecules (called with ``n_cols=5``); the interface renders it
        in a ``gr.HTML`` component.

    Raises:
        KeyError: If the pickle is an int-keyed dict without ``prompt_id``.
    """
    # A pathlib.Path also has a `.name` attribute, but it is only the basename,
    # so path-like inputs must be recognised before falling back to `.name`.
    if isinstance(prompt_file, (str, pathlib.PurePath)):
        path = prompt_file
    else:
        path = prompt_file.name

    # SECURITY(review): pickle.load can execute arbitrary code embedded in the
    # file, and this file is a user upload. Only expose this demo to trusted
    # users, or move the prompt format to a data-only serialization.
    with open(path, "rb") as f:
        prompts = pickle.load(f)

    # An int-keyed dict is a collection of prompts to pick from; anything else
    # (e.g. a single prompt dict with string keys) is used as the prompt itself.
    if isinstance(prompts, dict) and all(isinstance(key, int) for key in prompts):
        # Gradio number widgets may deliver floats; dict keys here are ints.
        key = int(prompt_id)
        try:
            prompt = prompts[key]
        except KeyError:
            raise KeyError(
                f"Prompt ID {key} not found; available IDs: {sorted(prompts)}"
            ) from None
    else:
        prompt = prompts

    config = GeoDiffGenerator(prompt=prompt)
    model = DiffusersGenerationAlgorithm(config)
    # gr.Slider may pass its value as a float; the sample count must be an int.
    results = list(model.sample(int(number_of_samples)))
    smiles = [Chem.MolToSmiles(mol) for mol in results]
    return draw_grid_generate(samples=smiles, n_cols=5)
if __name__ == "__main__":
    # Demo metadata (example prompts plus article/description markdown) lives
    # in the `model_cards` directory next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # Each example row is [prompt file, prompt ID, number of samples], in the
    # same order as `inputs` below. File examples are given as plain filepath
    # strings, which is the form Gradio's File component expects.
    example_prompts = str(metadata_root.joinpath("mol_dct.pkl"))
    examples = [
        [example_prompts, 0, 2],
        [example_prompts, 1, 2],
    ]

    # Read the markdown as UTF-8 explicitly; the platform default encoding
    # (e.g. cp1252 on Windows) can fail on non-ASCII characters.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="GeoDiff",
        inputs=[
            gr.File(file_types=[".pkl"], label="GeoDiff prompt"),
            gr.Number(value=0, label="Prompt ID", precision=0),
            gr.Slider(minimum=1, maximum=5, value=2, label="Number of samples", step=1),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True)
|