File size: 3,064 Bytes
2e605bf
 
 
 
 
7198503
89e8857
2e605bf
7198503
2e605bf
 
 
 
89e8857
 
 
 
58d00bc
89e8857
 
 
 
2e605bf
7198503
89e8857
7198503
 
2e605bf
89e8857
2e605bf
7198503
 
89e8857
7198503
89e8857
2e605bf
 
 
7198503
 
2e605bf
 
89e8857
7198503
 
 
 
2e605bf
 
 
 
 
89e8857
 
 
 
 
 
2e605bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7c5eb2
2e605bf
 
7198503
2e605bf
7198503
 
89e8857
 
 
 
 
 
 
 
 
 
2e605bf
 
 
 
 
 
7198503
2e605bf
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.hugging_face import (
    HuggingFaceSeq2SeqGenerator,
    HuggingFaceGenerationAlgorithm,
)
from transformers import AutoTokenizer

# Module-level logger. Only a no-op NullHandler is attached here, so any
# actual log output is left to the application's logging configuration.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Maps each supported task name to the natural-language instruction that
# `run_inference` passes to the generator as the prompt prefix. The keys are
# the same task names offered by the "Task" radio group in the UI below.
task2prefix = {
    "forward": "Predict the product of the following reaction: ",
    "retrosynthesis": "Predict the reaction that produces the following product: ",
    "paragraph to actions": "Which actions are described in the following paragraph: ",
    "molecular captioning": "Caption the following smile: ",
    "text-conditional de novo generation": "Write in SMILES the described molecule: ",
}


# Cache of (eos_token, pad_token) pairs keyed by tokenizer name, so the
# tokenizer is loaded once per process instead of once per request.
_SPECIAL_TOKENS_CACHE = {}


def _get_special_tokens(tokenizer_name):
    """Return the cached ``(eos_token, pad_token)`` pair of *tokenizer_name*."""
    if tokenizer_name not in _SPECIAL_TOKENS_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        _SPECIAL_TOKENS_CACHE[tokenizer_name] = (
            tokenizer.eos_token,
            tokenizer.pad_token,
        )
    return _SPECIAL_TOKENS_CACHE[tokenizer_name]


def run_inference(
    model_name_or_path: str,
    task: str,
    prompt: str,
    num_beams: int,
) -> str:
    """Generate one sample for *prompt* and return it as cleaned-up text.

    Args:
        model_name_or_path: Passed to the generator configuration as
            ``algorithm_version``.
        task: Key of ``task2prefix`` selecting the instruction that is used
            as the prompt prefix.
        prompt: User-supplied text appended to the task instruction.
        num_beams: Passed through to the generator configuration.

    Returns:
        The first generated sample with the echoed ``instruction + prompt``
        removed, truncated at the first end-of-sequence token, stripped of
        padding tokens and of surrounding whitespace.

    Raises:
        KeyError: If *task* is not a key of ``task2prefix``.
    """
    instruction = task2prefix[task]

    config = HuggingFaceSeq2SeqGenerator(
        algorithm_version=model_name_or_path,
        prefix=instruction,
        prompt=prompt,
        num_beams=num_beams,
    )

    model = HuggingFaceGenerationAlgorithm(config)
    # Only the special-token strings are needed for post-processing, so they
    # are looked up through the cache rather than reloading the tokenizer.
    eos_token, pad_token = _get_special_tokens("t5-small")

    text = list(model.sample(1))[0]

    # Drop the echoed input, then everything from the first EOS marker on.
    text = text.replace(instruction + prompt, "")
    # Guard against a tokenizer without these tokens: `split(None)` would
    # silently split on whitespace and `replace(None, ...)` would raise.
    if eos_token:
        text = text.split(eos_token)[0]
    if pad_token:
        text = text.replace(pad_token, "")

    return text.strip()


if __name__ == "__main__":

    # Model versions offered in the "Language model" dropdown.
    models = [
        "text-chem-t5-small-standard",
        "text-chem-t5-small-augm",
        "text-chem-t5-base-standard",
        "text-chem-t5-base-augm",
    ]

    # Static UI assets live in a "model_cards" directory next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # The examples file is read without a header row; missing cells become
    # empty strings. The rows are converted to plain lists once and reused.
    examples = pd.read_csv(metadata_root.joinpath("examples.csv"), header=None).fillna(
        ""
    )
    example_rows = examples.values.tolist()
    print("Examples: ", example_rows)

    # Read the markdown with an explicit encoding so the result does not
    # depend on the platform's default locale.
    # NOTE(review): assumes the model-card files are UTF-8 - confirm.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="Text+chem-T5 model",
        # Input components are listed in the order of run_inference's
        # parameters: model, task, prompt, num_beams.
        inputs=[
            gr.Dropdown(
                models,
                label="Language model",
                value="text-chem-t5-base-augm",
            ),
            gr.Radio(
                # Derived from task2prefix so the offered tasks cannot drift
                # from the prefixes run_inference looks up.
                choices=list(task2prefix),
                label="Task",
                value="paragraph to actions",
            ),
            gr.Textbox(
                label="Text prompt",
                placeholder="I'm a stochastic parrot.",
                lines=1,
            ),
            gr.Slider(minimum=1, maximum=50, value=10, label="num_beams", step=1),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=example_rows,
    )
    demo.launch(debug=True, show_error=True)