File size: 1,868 Bytes
5da68a0
 
b68abc1
5da68a0
b68abc1
5da68a0
 
 
b68abc1
5da68a0
b68abc1
 
5da68a0
 
 
 
 
d2d894e
5da68a0
b68abc1
 
 
 
d2d894e
b68abc1
5da68a0
b68abc1
 
d2d894e
5da68a0
b68abc1
 
5da68a0
b68abc1
5da68a0
 
 
 
 
 
 
d2d894e
 
 
 
5da68a0
 
 
 
 
 
 
 
b68abc1
5da68a0
b68abc1
 
 
5da68a0
b68abc1
5da68a0
 
b68abc1
5da68a0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import logging
import pathlib
import pickle
import gradio as gr
from typing import Dict, Any
import pandas as pd
from gt4sd.algorithms.generation.diffusion import (
    DiffusersGenerationAlgorithm,
    GeoDiffGenerator,
)
from utils import draw_grid_generate
from rdkit import Chem

# Module-level logger named after this module, with a no-op NullHandler attached.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def run_inference(prompt_file: Any, prompt_id: int, number_of_samples: int):
    """Generate molecules with GeoDiff from a pickled prompt and draw them.

    Args:
        prompt_file: Pickle file holding the prompt(s). Either a filesystem
            path (``str`` or a ``pathlib`` path) or an object exposing the
            path as ``.name`` (e.g. an uploaded-file wrapper).
        prompt_id: Key of the prompt to use when the pickle holds a mapping
            whose keys are all integers; ignored otherwise.
        number_of_samples: Number of samples requested from the model.

    Returns:
        Whatever ``draw_grid_generate`` returns for the generated SMILES
        strings, laid out with 5 columns.

    Raises:
        ValueError: If no prompt file was provided.
        KeyError: If ``prompt_id`` is not a key of an integer-keyed pickle.
    """
    if prompt_file is None:
        raise ValueError("No prompt file provided.")

    # Paths are used as-is. ``pathlib`` paths must be checked explicitly
    # because their ``.name`` is only the final path component, not the
    # full location. Anything else is expected to carry the path in ``.name``.
    if isinstance(prompt_file, (str, pathlib.PurePath)):
        prompt_path = prompt_file
    else:
        prompt_path = prompt_file.name

    # SECURITY: unpickling can execute arbitrary code embedded in the file.
    # Only load prompt files that come from a trusted source.
    with open(prompt_path, "rb") as f:
        prompts = pickle.load(f)

    # An all-integer-keyed mapping is treated as a collection of prompts
    # indexed by ID; any other mapping is taken to be a single prompt.
    if all(isinstance(key, int) for key in prompts.keys()):
        try:
            prompt = prompts[prompt_id]
        except KeyError:
            raise KeyError(
                f"Prompt ID {prompt_id} not found in prompt file; "
                f"available IDs: {sorted(prompts)}"
            ) from None
    else:
        prompt = prompts

    config = GeoDiffGenerator(prompt=prompt)
    model = DiffusersGenerationAlgorithm(config)
    results = list(model.sample(number_of_samples))
    smiles = [Chem.MolToSmiles(m) for m in results]

    return draw_grid_generate(samples=smiles, n_cols=5)


if __name__ == "__main__":

    # Static assets (example prompt pickle and page texts) are read from the
    # "model_cards" directory next to this script.
    card_dir = pathlib.Path(__file__).parent / "model_cards"
    example_pickle = card_dir / "mol_dct.pkl"

    # Page texts passed to the interface below.
    article = (card_dir / "article.md").read_text()
    description = (card_dir / "description.md").read_text()

    # One example per prompt ID (0 and 1), each requesting 2 samples.
    examples = [[example_pickle, example_id, 2] for example_id in (0, 1)]

    # Input widgets, in the positional order of run_inference's parameters.
    prompt_input = gr.File(file_types=[".pkl"], label="GeoDiff prompt")
    prompt_id_input = gr.Number(value=0, label="Prompt ID", precision=0)
    sample_count_input = gr.Slider(
        minimum=1, maximum=5, value=2, label="Number of samples", step=1
    )

    demo = gr.Interface(
        fn=run_inference,
        title="GeoDiff",
        inputs=[prompt_input, prompt_id_input, sample_count_input],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True)